Skip to content

Commit ccfed66

Browse files
fixup! feat: support unknown parameters during initialization
1 parent 41fe639 commit ccfed66

File tree

2 files changed

+40
-34
lines changed

2 files changed

+40
-34
lines changed

src/systems/abstractsystem.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1849,14 +1849,17 @@ function linearization_function(sys::AbstractSystem, inputs,
18491849
end
18501850
initfn = NonlinearFunction(initsys)
18511851
initprobmap = getu(initsys, unknowns(sys))
1852+
initprob_init! = generate_initializeprob_init(sys, initsys)
1853+
initprob_update! = generate_initializeprob_update(sys, initsys)
18521854
ps = full_parameters(sys)
18531855
lin_fun = let diff_idxs = diff_idxs,
18541856
alge_idxs = alge_idxs,
18551857
input_idxs = input_idxs,
18561858
sts = unknowns(sys),
18571859
get_initprob_u_p = get_initprob_u_p,
18581860
fun = ODEFunction{true, SciMLBase.FullSpecialize}(
1859-
sys, unknowns(sys), ps; initializeprobmap = initprobmap),
1861+
sys, unknowns(sys), ps; initializeprob_init! = initprob_init!,
1862+
initializeprob_update! = initprob_update!),
18601863
initfn = initfn,
18611864
h = build_explicit_observed_function(sys, outputs),
18621865
chunk = ForwardDiff.Chunk(input_idxs),

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,25 @@ function isautonomous(sys::AbstractODESystem)
280280
all(iszero, tgrad)
281281
end
282282

283+
struct GetAndSetFunctor{G, S}
284+
getter::G
285+
setter::S
286+
end
287+
288+
function (gs::GetAndSetFunctor)(dest, source)
289+
gs.setter(dest, gs.getter(source))
290+
end
291+
292+
function generate_initializeprob_init(sys::AbstractSystem, initsys::AbstractSystem)
293+
syms = vcat(variable_symbols(initsys), parameter_symbols(initsys))
294+
return GetAndSetFunctor(getu(sys, syms), setu(initsys, syms))
295+
end
296+
297+
function generate_initializeprob_update(sys::AbstractSystem, initsys::AbstractSystem)
298+
syms = vcat(variable_symbols(sys), parameter_symbols(sys))
299+
return GetAndSetFunctor(getu(initsys, syms), setu(sys, syms))
300+
end
301+
283302
"""
284303
```julia
285304
DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
@@ -323,8 +342,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
323342
analytic = nothing,
324343
split_idxs = nothing,
325344
initializeprob = nothing,
326-
initializeprobmap = nothing,
327-
initializeprob_updatep! = nothing,
345+
initializeprob_init! = nothing,
346+
initializeprob_update! = nothing,
328347
kwargs...) where {iip, specialize}
329348
if !iscomplete(sys)
330349
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`")
@@ -507,8 +526,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
507526
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
508527
analytic = analytic,
509528
initializeprob = initializeprob,
510-
initializeprobmap = initializeprobmap,
511-
initializeprob_updatep! = initializeprob_updatep!)
529+
initializeprob_init! = initializeprob_init!,
530+
initializeprob_update! = initializeprob_update!)
512531
end
513532

514533
"""
@@ -539,8 +558,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
539558
eval_module = @__MODULE__,
540559
checkbounds = false,
541560
initializeprob = nothing,
542-
initializeprobmap = nothing,
543-
initializeprob_updatep! = nothing,
561+
initializeprob_init! = nothing,
562+
initializeprob_update! = nothing,
544563
kwargs...) where {iip}
545564
if !iscomplete(sys)
546565
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
@@ -614,8 +633,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
614633
jac_prototype = jac_prototype,
615634
observed = observedfun,
616635
initializeprob = initializeprob,
617-
initializeprobmap = initializeprobmap,
618-
initializeprob_updatep! = initializeprob_updatep!)
636+
initializeprob_init! = initializeprob_init!,
637+
initializeprob_update! = initializeprob_update!)
619638
end
620639

621640
function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
@@ -927,26 +946,11 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
927946
end
928947
initializeprob = ModelingToolkit.InitializationProblem(
929948
sys, t, u0map, parammap; guesses, warn_initialize_determined)
930-
unks = unknowns(sys)
931-
initializeprobmap = isempty(unks) ? (_...) -> nothing :
932-
getu(initializeprob, unknowns(sys))
933-
if any(p -> is_variable(initializeprob, p) || is_observed(initializeprob, p),
934-
parameters(sys))
935-
punknowns = [p
936-
for p in parameters(sys)
937-
if is_variable(initializeprob, p) ||
938-
is_observed(initializeprob, p)]
939-
initializeprob_updatep! = let getter = getu(initializeprob, tovar.(punknowns)),
940-
setter = setp(sys, punknowns)
941-
942-
function (ps, initsol)
943-
setter(ps, getter(initsol))
944-
end
945-
end
946-
else
947-
punknowns = []
948-
initializeprob_updatep! = nothing
949-
end
949+
punknowns = [p
950+
for p in parameters(sys)
951+
if is_variable(initializeprob, p) || is_observed(initializeprob, p)]
952+
initializeprob_init! = generate_initializeprob_init(sys, initializeprob.f.sys)
953+
initializeprob_update! = generate_initializeprob_update(sys, initializeprob.f.sys)
950954
zerovars = Dict(setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0)
951955
zeropars = Dict()
952956
for p in punknowns
@@ -961,9 +965,9 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
961965
(trueinit = SVector{length(trueinit)}(trueinit))
962966
else
963967
initializeprob = nothing
964-
initializeprobmap = nothing
965-
initializeprob_updatep! = nothing
966968
zeropars = Dict()
969+
initializeprob_init! = nothing
970+
initializeprob_update! = nothing
967971
trueinit = u0map
968972
end
969973

@@ -1012,9 +1016,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
10121016
checkbounds = checkbounds, p = p,
10131017
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
10141018
sparse = sparse, eval_expression = eval_expression,
1015-
initializeprob = initializeprob,
1016-
initializeprobmap = initializeprobmap,
1017-
initializeprob_updatep! = initializeprob_updatep!,
1019+
initializeprob = initializeprob, initializeprob_init! = initializeprob_init!,
1020+
initializeprob_update! = initializeprob_update!,
10181021
kwargs...)
10191022
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
10201023
end

0 commit comments

Comments
 (0)