Skip to content

Commit 6c9d5e6

Browse files
feat: support update_initializeprob!
1 parent f298999 commit 6c9d5e6

File tree

3 files changed

+42
-1
lines changed

3 files changed

+42
-1
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
331331
analytic = nothing,
332332
split_idxs = nothing,
333333
initializeprob = nothing,
334+
update_initializeprob! = nothing,
334335
initializeprobmap = nothing,
335336
initializeprobpmap = nothing,
336337
kwargs...) where {iip, specialize}
@@ -434,6 +435,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
434435
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
435436
analytic = analytic,
436437
initializeprob = initializeprob,
438+
update_initializeprob! = update_initializeprob!,
437439
initializeprobmap = initializeprobmap,
438440
initializeprobpmap = initializeprobpmap)
439441
end
@@ -778,6 +780,17 @@ function (f::GetUpdatedMTKParameters)(prob, initializesol)
778780
mtkp
779781
end
780782

783+
struct UpdateInitializeprob{G, S}
784+
# `getu` functor which gets all values from prob
785+
getvals::G
786+
# `setu` functor which updates initializeprob with values
787+
setvals::S
788+
end
789+
790+
function (f::UpdateInitializeprob)(initializeprob, prob)
791+
f.setvals(initializeprob, f.getvals(prob))
792+
end
793+
781794
function get_temporary_value(p)
782795
stype = symtype(unwrap(p))
783796
return if stype == Real
@@ -866,6 +879,10 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
866879
getpunknowns = getu(initializeprob, punknowns)
867880
setpunknowns = setp(sys, punknowns)
868881
initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
882+
reqd_syms = vcat(
883+
variable_symbols(initializeprob), parameter_symbols(initializeprob))
884+
update_initializeprob! = UpdateInitializeprob(
885+
getu(sys, reqd_syms), setu(initializeprob, reqd_syms))
869886

870887
zerovars = Dict(setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0)
871888
if parammap isa SciMLBase.NullParameters
@@ -881,6 +898,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
881898
(trueinit = SVector{length(trueinit)}(trueinit))
882899
else
883900
initializeprob = nothing
901+
update_initializeprob! = nothing
884902
initializeprobmap = nothing
885903
initializeprobpmap = nothing
886904
trueinit = u0map
@@ -930,6 +948,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
930948
sparse = sparse, eval_expression = eval_expression,
931949
eval_module = eval_module,
932950
initializeprob = initializeprob,
951+
update_initializeprob! = update_initializeprob!,
933952
initializeprobmap = initializeprobmap,
934953
initializeprobpmap = initializeprobpmap,
935954
kwargs...)

src/systems/nonlinear/initializesystem.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,5 +200,8 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
200200
getpunknowns = getu(initprob, punknowns)
201201
setpunknowns = setp(sys, punknowns)
202202
initprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
203-
return initprob, initprobmap, initprobpmap
203+
reqd_syms = vcat(variable_symbols(initprob), parameter_symbols(initprob))
204+
update_initializeprob! = UpdateInitializeprob(
205+
getu(sys, reqd_syms), setu(initprob, reqd_syms))
206+
return initprob, update_initializeprob!, initprobmap, initprobpmap
204207
end

test/initializationsystem.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,3 +704,22 @@ end
704704
prob5 = remake(prob)
705705
@test init(prob, Tsit5()).ps[p] 2.0
706706
end
707+
708+
@testset "Update initializeprob parameters" begin
709+
@variables x(t) y(t)
710+
@parameters p q
711+
@mtkbuild sys = ODESystem(
712+
[D(x) ~ x, p ~ x + y], t; guesses = [x => 0.0, p => 0.0])
713+
prob = ODEProblem(sys, [y => 1.0], (0.0, 1.0), [p => 3.0])
714+
@test prob.f.initializeprob.ps[p] 3.0
715+
@test init(prob, Tsit5())[x] 2.0
716+
prob.ps[p] = 2.0
717+
@test prob.f.initializeprob.ps[p] 3.0
718+
@test init(prob, Tsit5())[x] 1.0
719+
ModelingToolkit.defaults(prob.f.sys)[p] = missing
720+
prob2 = remake(prob; u0 = [y => 1.0], p = [p => 3x])
721+
@test !is_variable(prob2.f.initializeprob, p) &&
722+
!is_parameter(prob2.f.initializeprob, p)
723+
@test init(prob2, Tsit5())[x] 0.5
724+
@test_nowarn solve(prob2, Tsit5())
725+
end

0 commit comments

Comments
 (0)