Skip to content

Commit 359fea7

Browse files
fix: fix eval_expression = true construction of problems
1 parent 52c231f commit 359fea7

File tree

1 file changed

+33
-22
lines changed

1 file changed

+33
-22
lines changed

src/systems/problem_utils.jl

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -701,9 +701,10 @@ function.
701701
Note that the getter ONLY works for problem-like objects, since it generates an observed
702702
function. It does NOT work for solutions.
703703
"""
704-
Base.@nospecializeinfer function concrete_getu(indp, syms::AbstractVector)
704+
Base.@nospecializeinfer function concrete_getu(indp, syms; eval_expression, eval_module)
705705
@nospecialize
706-
obsfn = build_explicit_observed_function(indp, syms; wrap_delays = false)
706+
obsfn = build_explicit_observed_function(
707+
indp, syms; wrap_delays = false, eval_expression, eval_module)
707708
return ObservedWrapper{is_time_dependent(indp)}(obsfn)
708709
end
709710

@@ -757,7 +758,8 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
757758
- `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`.
758759
"""
759760
function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem;
760-
initials = false, unwrap_initials = false, p_constructor = identity)
761+
initials = false, unwrap_initials = false, p_constructor = identity,
762+
eval_expression = false, eval_module = @__MODULE__)
761763
_p_constructor = p_constructor
762764
p_constructor = PConstructorApplicator(p_constructor)
763765
# if we call `getu` on this (and it were able to handle empty tuples) we get the
@@ -773,7 +775,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
773775
tunable_getter = if isempty(tunable_syms)
774776
Returns(SizedVector{0, Float64}())
775777
else
776-
p_constructor concrete_getu(srcsys, tunable_syms)
778+
p_constructor concrete_getu(srcsys, tunable_syms; eval_expression, eval_module)
777779
end
778780
initials_getter = if initials && !isempty(syms[2])
779781
initsyms = Vector{Any}(syms[2])
@@ -792,7 +794,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
792794
end
793795
end
794796
end
795-
p_constructor concrete_getu(srcsys, initsyms)
797+
p_constructor concrete_getu(srcsys, initsyms; eval_expression, eval_module)
796798
else
797799
Returns(SizedVector{0, Float64}())
798800
end
@@ -810,7 +812,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
810812
# tuple of `BlockedArray`s
811813
Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes)
812814
Base.Fix1(broadcast, p_constructor)
813-
getu(srcsys, syms[3])
815+
concrete_getu(srcsys, syms[3]; eval_expression, eval_module)
814816
end
815817
const_getter = if syms[4] == ()
816818
Returns(())
@@ -826,7 +828,8 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
826828
end)
827829
# nonnumerics retain the assigned buffer type without narrowing
828830
Base.Fix1(broadcast, _p_constructor)
829-
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) getu(srcsys, syms[5])
831+
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes)
832+
concrete_getu(srcsys, syms[5]; eval_expression, eval_module)
830833
end
831834
getters = (
832835
tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter)
@@ -853,14 +856,19 @@ Construct a `ReconstructInitializeprob` which reconstructs the `u0` and `p` of `
853856
with values from `srcsys`.
854857
"""
855858
function ReconstructInitializeprob(
856-
srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity)
859+
srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity,
860+
eval_expression = false, eval_module = @__MODULE__)
857861
@assert is_initializesystem(dstsys)
858-
ugetter = u0_constructor getu(srcsys, unknowns(dstsys))
862+
ugetter = u0_constructor
863+
concrete_getu(srcsys, unknowns(dstsys); eval_expression, eval_module)
859864
if is_split(dstsys)
860-
pgetter = get_mtkparameters_reconstructor(srcsys, dstsys; p_constructor)
865+
pgetter = get_mtkparameters_reconstructor(
866+
srcsys, dstsys; p_constructor, eval_expression, eval_module)
861867
else
862868
syms = parameters(dstsys)
863-
pgetter = let inner = concrete_getu(srcsys, syms), p_constructor = p_constructor
869+
pgetter = let inner = concrete_getu(srcsys, syms; eval_expression, eval_module),
870+
p_constructor = p_constructor
871+
864872
function _getter2(valp, initprob)
865873
p_constructor(inner(valp))
866874
end
@@ -924,17 +932,19 @@ Given `sys` and its corresponding initialization system `initsys`, return the
924932
`initializeprobpmap` function in `OverrideInitData` for the systems.
925933
"""
926934
function construct_initializeprobpmap(
927-
sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity)
935+
sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity, eval_expression, eval_module)
928936
@assert is_initializesystem(initsys)
929937
if is_split(sys)
930938
return let getter = get_mtkparameters_reconstructor(
931-
initsys, sys; initials = true, unwrap_initials = true, p_constructor)
939+
initsys, sys; initials = true, unwrap_initials = true, p_constructor,
940+
eval_expression, eval_module)
932941
function initprobpmap_split(prob, initsol)
933942
getter(initsol, prob)
934943
end
935944
end
936945
else
937-
return let getter = getu(initsys, parameters(sys; initial_parameters = true)),
946+
return let getter = concrete_getu(initsys, parameters(sys; initial_parameters = true);
947+
eval_expression, eval_module),
938948
p_constructor = p_constructor
939949

940950
function initprobpmap_nosplit(prob, initsol)
@@ -1039,14 +1049,14 @@ struct GetUpdatedU0{GG, GIU}
10391049
get_initial_unknowns::GIU
10401050
end
10411051

1042-
function GetUpdatedU0(sys::AbstractSystem, initsys::AbstractSystem, op::AbstractDict)
1052+
function GetUpdatedU0(sys::AbstractSystem, initprob::SciMLBase.AbstractNonlinearProblem, op::AbstractDict)
10431053
dvs = unknowns(sys)
10441054
eqs = equations(sys)
10451055
guessvars = trues(length(dvs))
10461056
for (i, var) in enumerate(dvs)
10471057
guessvars[i] = !isequal(get(op, var, nothing), Initial(var))
10481058
end
1049-
get_guessvars = getu(initsys, dvs[guessvars])
1059+
get_guessvars = getu(initprob, dvs[guessvars])
10501060
get_initial_unknowns = getu(sys, Initial.(dvs))
10511061
return GetUpdatedU0(guessvars, get_guessvars, get_initial_unknowns)
10521062
end
@@ -1108,7 +1118,7 @@ function maybe_build_initialization_problem(
11081118
guesses, missing_unknowns; implicit_dae = false,
11091119
time_dependent_init = is_time_dependent(sys), u0_constructor = identity,
11101120
p_constructor = identity, floatT = Float64, initialization_eqs = [],
1111-
use_scc = true, kwargs...)
1121+
use_scc = true, eval_expression = false, eval_module = @__MODULE__, kwargs...)
11121122
guesses = merge(ModelingToolkit.guesses(sys), todict(guesses))
11131123

11141124
if t === nothing && is_time_dependent(sys)
@@ -1117,7 +1127,7 @@ function maybe_build_initialization_problem(
11171127

11181128
initializeprob = ModelingToolkit.InitializationProblem{iip}(
11191129
sys, t, op; guesses, time_dependent_init, initialization_eqs,
1120-
use_scc, u0_constructor, p_constructor, kwargs...)
1130+
use_scc, u0_constructor, p_constructor, eval_expression, eval_module, kwargs...)
11211131
if state_values(initializeprob) !== nothing
11221132
_u0 = state_values(initializeprob)
11231133
if ArrayInterface.ismutable(_u0)
@@ -1145,15 +1155,16 @@ function maybe_build_initialization_problem(
11451155
initializeprob = remake(initializeprob; p = initp)
11461156

11471157
get_initial_unknowns = if time_dependent_init
1148-
GetUpdatedU0(sys, initializeprob.f.sys, op)
1158+
GetUpdatedU0(sys, initializeprob, op)
11491159
else
11501160
nothing
11511161
end
11521162
meta = InitializationMetadata(
11531163
copy(op), copy(guesses), Vector{Equation}(initialization_eqs),
11541164
use_scc, time_dependent_init,
11551165
ReconstructInitializeprob(
1156-
sys, initializeprob.f.sys; u0_constructor, p_constructor),
1166+
sys, initializeprob.f.sys; u0_constructor,
1167+
p_constructor, eval_expression, eval_module),
11571168
get_initial_unknowns, SetInitialUnknowns(sys))
11581169

11591170
if time_dependent_init
@@ -1172,7 +1183,7 @@ function maybe_build_initialization_problem(
11721183
initializeprobpmap = nothing
11731184
else
11741185
initializeprobpmap = construct_initializeprobpmap(
1175-
sys, initializeprob.f.sys; p_constructor)
1186+
sys, initializeprob.f.sys; p_constructor, eval_expression, eval_module)
11761187
end
11771188

11781189
# we still want the `initialization_data` because it helps with `remake`
@@ -1468,7 +1479,7 @@ function process_SciMLProblem(
14681479
if is_time_dependent(sys) && t0 === nothing
14691480
t0 = zero(floatT)
14701481
end
1471-
initialization_data = SciMLBase.remake_initialization_data(
1482+
initialization_data = @invokelatest SciMLBase.remake_initialization_data(
14721483
sys, kwargs, u0, t0, p, u0, p)
14731484
kwargs = merge(kwargs, (; initialization_data))
14741485
end

0 commit comments

Comments
 (0)