Skip to content

Commit 0a12d70

Browse files
Merge pull request #3957 from SciML/as/precompile-odeprob
fix: fix `ODEProblem` construction during precompilation with `eval_expression = true`
2 parents 285e84c + 5f99265 commit 0a12d70

File tree

9 files changed

+102
-40
lines changed

9 files changed

+102
-40
lines changed

docs/src/API/codegen.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,5 @@ ModelingToolkit.calculate_A_b
5050
All code generation eventually calls `build_function_wrapper`.
5151

5252
```@docs
53-
build_function_wrapper
53+
ModelingToolkit.build_function_wrapper
5454
```

src/systems/diffeqs/basic_transformations.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,9 +1037,11 @@ function respecialize(sys::AbstractSystem, mapping; all = false)
10371037
"""
10381038

10391039
if iscall(k)
1040-
op = operation(k)
1040+
op = operation(k)::BasicSymbolic
1041+
@assert !iscall(op)
1042+
op = SymbolicUtils.Sym{SymbolicUtils.FnType{Tuple{Any}, T}}(nameof(op))
10411043
args = arguments(k)
1042-
new_p = SymbolicUtils.term(op, args...; type = T)
1044+
new_p = op(args...)
10431045
else
10441046
new_p = SymbolicUtils.Sym{T}(getname(k))
10451047
end

src/systems/parameter_buffer.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -763,8 +763,14 @@ function __remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = tru
763763
oldbuf.discrete, newbuf.discrete)
764764
@set! newbuf.constant = narrow_buffer_type_and_fallback_undefs.(
765765
oldbuf.constant, newbuf.constant)
766-
@set! newbuf.nonnumeric = narrow_buffer_type_and_fallback_undefs.(
767-
oldbuf.nonnumeric, newbuf.nonnumeric)
766+
for (oldv, newv) in zip(oldbuf.nonnumeric, newbuf.nonnumeric)
767+
for i in eachindex(oldv)
768+
isassigned(newv, i) && continue
769+
newv[i] = oldv[i]
770+
end
771+
end
772+
@set! newbuf.nonnumeric = Tuple(
773+
typeof(oldv)(newv) for (oldv, newv) in zip(oldbuf.nonnumeric, newbuf.nonnumeric))
768774
if !ArrayInterface.ismutable(oldbuf)
769775
@set! newbuf.tunable = similar_type(oldbuf.tunable, eltype(newbuf.tunable))(newbuf.tunable)
770776
@set! newbuf.initials = similar_type(oldbuf.initials, eltype(newbuf.initials))(newbuf.initials)
@@ -820,7 +826,7 @@ function SciMLBase.create_parameter_timeseries_collection(
820826
isempty(ps.discrete) && return nothing
821827
num_discretes = only(blocksize(ps.discrete[1]))
822828
buffers = []
823-
partition_type = Tuple{(Vector{eltype(buf)} for buf in ps.discrete)...}
829+
partition_type = Tuple{(typeof(parent(buf)) for buf in ps.discrete)...}
824830
for i in 1:num_discretes
825831
ts = eltype(tspan)[]
826832
us = NestedGetIndex{partition_type}[]

src/systems/problem_utils.jl

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -701,9 +701,10 @@ function.
701701
Note that the getter ONLY works for problem-like objects, since it generates an observed
702702
function. It does NOT work for solutions.
703703
"""
704-
Base.@nospecializeinfer function concrete_getu(indp, syms::AbstractVector)
704+
Base.@nospecializeinfer function concrete_getu(indp, syms; eval_expression, eval_module)
705705
@nospecialize
706-
obsfn = build_explicit_observed_function(indp, syms; wrap_delays = false)
706+
obsfn = build_explicit_observed_function(
707+
indp, syms; wrap_delays = false, eval_expression, eval_module)
707708
return ObservedWrapper{is_time_dependent(indp)}(obsfn)
708709
end
709710

@@ -757,7 +758,8 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
757758
- `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`.
758759
"""
759760
function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem;
760-
initials = false, unwrap_initials = false, p_constructor = identity)
761+
initials = false, unwrap_initials = false, p_constructor = identity,
762+
eval_expression = false, eval_module = @__MODULE__)
761763
_p_constructor = p_constructor
762764
p_constructor = PConstructorApplicator(p_constructor)
763765
# if we call `getu` on this (and it were able to handle empty tuples) we get the
@@ -773,7 +775,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
773775
tunable_getter = if isempty(tunable_syms)
774776
Returns(SizedVector{0, Float64}())
775777
else
776-
p_constructor concrete_getu(srcsys, tunable_syms)
778+
p_constructor concrete_getu(srcsys, tunable_syms; eval_expression, eval_module)
777779
end
778780
initials_getter = if initials && !isempty(syms[2])
779781
initsyms = Vector{Any}(syms[2])
@@ -792,7 +794,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
792794
end
793795
end
794796
end
795-
p_constructor concrete_getu(srcsys, initsyms)
797+
p_constructor concrete_getu(srcsys, initsyms; eval_expression, eval_module)
796798
else
797799
Returns(SizedVector{0, Float64}())
798800
end
@@ -810,7 +812,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
810812
# tuple of `BlockedArray`s
811813
Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes)
812814
Base.Fix1(broadcast, p_constructor)
813-
getu(srcsys, syms[3])
815+
concrete_getu(srcsys, syms[3]; eval_expression, eval_module)
814816
end
815817
const_getter = if syms[4] == ()
816818
Returns(())
@@ -826,7 +828,8 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
826828
end)
827829
# nonnumerics retain the assigned buffer type without narrowing
828830
Base.Fix1(broadcast, _p_constructor)
829-
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) getu(srcsys, syms[5])
831+
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes)
832+
concrete_getu(srcsys, syms[5]; eval_expression, eval_module)
830833
end
831834
getters = (
832835
tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter)
@@ -853,14 +856,19 @@ Construct a `ReconstructInitializeprob` which reconstructs the `u0` and `p` of `
853856
with values from `srcsys`.
854857
"""
855858
function ReconstructInitializeprob(
856-
srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity)
859+
srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity,
860+
eval_expression = false, eval_module = @__MODULE__)
857861
@assert is_initializesystem(dstsys)
858-
ugetter = u0_constructor getu(srcsys, unknowns(dstsys))
862+
ugetter = u0_constructor
863+
concrete_getu(srcsys, unknowns(dstsys); eval_expression, eval_module)
859864
if is_split(dstsys)
860-
pgetter = get_mtkparameters_reconstructor(srcsys, dstsys; p_constructor)
865+
pgetter = get_mtkparameters_reconstructor(
866+
srcsys, dstsys; p_constructor, eval_expression, eval_module)
861867
else
862868
syms = parameters(dstsys)
863-
pgetter = let inner = concrete_getu(srcsys, syms), p_constructor = p_constructor
869+
pgetter = let inner = concrete_getu(srcsys, syms; eval_expression, eval_module),
870+
p_constructor = p_constructor
871+
864872
function _getter2(valp, initprob)
865873
p_constructor(inner(valp))
866874
end
@@ -924,18 +932,20 @@ Given `sys` and its corresponding initialization system `initsys`, return the
924932
`initializeprobpmap` function in `OverrideInitData` for the systems.
925933
"""
926934
function construct_initializeprobpmap(
927-
sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity)
935+
sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity, eval_expression, eval_module)
928936
@assert is_initializesystem(initsys)
929937
if is_split(sys)
930938
return let getter = get_mtkparameters_reconstructor(
931-
initsys, sys; initials = true, unwrap_initials = true, p_constructor)
939+
initsys, sys; initials = true, unwrap_initials = true, p_constructor,
940+
eval_expression, eval_module)
932941
function initprobpmap_split(prob, initsol)
933942
getter(initsol, prob)
934943
end
935944
end
936945
else
937-
return let getter = getu(initsys, parameters(sys; initial_parameters = true)),
938-
p_constructor = p_constructor
946+
return let getter = concrete_getu(
947+
initsys, parameters(sys; initial_parameters = true);
948+
eval_expression, eval_module), p_constructor = p_constructor
939949

940950
function initprobpmap_nosplit(prob, initsol)
941951
return p_constructor(getter(initsol))
@@ -1039,14 +1049,14 @@ struct GetUpdatedU0{GG, GIU}
10391049
get_initial_unknowns::GIU
10401050
end
10411051

1042-
function GetUpdatedU0(sys::AbstractSystem, initsys::AbstractSystem, op::AbstractDict)
1052+
function GetUpdatedU0(sys::AbstractSystem, initprob::SciMLBase.AbstractNonlinearProblem, op::AbstractDict)
10431053
dvs = unknowns(sys)
10441054
eqs = equations(sys)
10451055
guessvars = trues(length(dvs))
10461056
for (i, var) in enumerate(dvs)
10471057
guessvars[i] = !isequal(get(op, var, nothing), Initial(var))
10481058
end
1049-
get_guessvars = getu(initsys, dvs[guessvars])
1059+
get_guessvars = getu(initprob, dvs[guessvars])
10501060
get_initial_unknowns = getu(sys, Initial.(dvs))
10511061
return GetUpdatedU0(guessvars, get_guessvars, get_initial_unknowns)
10521062
end
@@ -1108,7 +1118,7 @@ function maybe_build_initialization_problem(
11081118
guesses, missing_unknowns; implicit_dae = false,
11091119
time_dependent_init = is_time_dependent(sys), u0_constructor = identity,
11101120
p_constructor = identity, floatT = Float64, initialization_eqs = [],
1111-
use_scc = true, kwargs...)
1121+
use_scc = true, eval_expression = false, eval_module = @__MODULE__, kwargs...)
11121122
guesses = merge(ModelingToolkit.guesses(sys), todict(guesses))
11131123

11141124
if t === nothing && is_time_dependent(sys)
@@ -1117,7 +1127,7 @@ function maybe_build_initialization_problem(
11171127

11181128
initializeprob = ModelingToolkit.InitializationProblem{iip}(
11191129
sys, t, op; guesses, time_dependent_init, initialization_eqs,
1120-
use_scc, u0_constructor, p_constructor, kwargs...)
1130+
use_scc, u0_constructor, p_constructor, eval_expression, eval_module, kwargs...)
11211131
if state_values(initializeprob) !== nothing
11221132
_u0 = state_values(initializeprob)
11231133
if ArrayInterface.ismutable(_u0)
@@ -1145,15 +1155,16 @@ function maybe_build_initialization_problem(
11451155
initializeprob = remake(initializeprob; p = initp)
11461156

11471157
get_initial_unknowns = if time_dependent_init
1148-
GetUpdatedU0(sys, initializeprob.f.sys, op)
1158+
GetUpdatedU0(sys, initializeprob, op)
11491159
else
11501160
nothing
11511161
end
11521162
meta = InitializationMetadata(
11531163
copy(op), copy(guesses), Vector{Equation}(initialization_eqs),
11541164
use_scc, time_dependent_init,
11551165
ReconstructInitializeprob(
1156-
sys, initializeprob.f.sys; u0_constructor, p_constructor),
1166+
sys, initializeprob.f.sys; u0_constructor,
1167+
p_constructor, eval_expression, eval_module),
11571168
get_initial_unknowns, SetInitialUnknowns(sys))
11581169

11591170
if time_dependent_init
@@ -1172,10 +1183,9 @@ function maybe_build_initialization_problem(
11721183
initializeprobpmap = nothing
11731184
else
11741185
initializeprobpmap = construct_initializeprobpmap(
1175-
sys, initializeprob.f.sys; p_constructor)
1186+
sys, initializeprob.f.sys; p_constructor, eval_expression, eval_module)
11761187
end
11771188

1178-
reqd_syms = parameter_symbols(initializeprob)
11791189
# we still want the `initialization_data` because it helps with `remake`
11801190
if initializeprobmap === nothing && initializeprobpmap === nothing
11811191
update_initializeprob! = nothing
@@ -1186,7 +1196,9 @@ function maybe_build_initialization_problem(
11861196
filter!(punknowns) do p
11871197
is_parameter_solvable(p, op, defs, guesses) && get(op, p, missing) === missing
11881198
end
1189-
pvals = getu(initializeprob, punknowns)(initializeprob)
1199+
# See comment below for why `getu` is not used here.
1200+
_pgetter = build_explicit_observed_function(initializeprob.f.sys, punknowns)
1201+
pvals = _pgetter(state_values(initializeprob), parameter_values(initializeprob))
11901202
for (p, pval) in zip(punknowns, pvals)
11911203
p = unwrap(p)
11921204
op[p] = pval
@@ -1198,7 +1210,13 @@ function maybe_build_initialization_problem(
11981210
end
11991211

12001212
if time_dependent_init
1201-
uvals = getu(initializeprob, collect(missing_unknowns))(initializeprob)
1213+
# We can't use `getu` here because that goes to `SII.observed`, which goes to
1214+
# `ObservedFunctionCache` which uses `eval_expression` and `eval_module`. If
1215+
# `eval_expression == true`, this then runs into world-age issues. Building an
1216+
# RGF here is fine since it is always discarded. We can't use `eval_module` for
1217+
# the RGF since the user may not have run RGF's init.
1218+
_ugetter = build_explicit_observed_function(initializeprob.f.sys, collect(missing_unknowns))
1219+
uvals = _ugetter(state_values(initializeprob), parameter_values(initializeprob))
12021220
for (v, val) in zip(missing_unknowns, uvals)
12031221
op[v] = val
12041222
end
@@ -1461,7 +1479,7 @@ function process_SciMLProblem(
14611479
if is_time_dependent(sys) && t0 === nothing
14621480
t0 = zero(floatT)
14631481
end
1464-
initialization_data = SciMLBase.remake_initialization_data(
1482+
initialization_data = @invokelatest SciMLBase.remake_initialization_data(
14651483
sys, kwargs, u0, t0, p, u0, p)
14661484
kwargs = merge(kwargs, (; initialization_data))
14671485
end
@@ -1773,7 +1791,8 @@ Construct SciMLProblem `T` with positional arguments `args` and keywords `kwargs
17731791
"""
17741792
function maybe_codegen_scimlproblem(::Type{Val{false}}, T, args::NamedTuple; kwargs...)
17751793
# Call `remake` so it runs initialization if it is trivial
1776-
remake(T(args...; kwargs...))
1794+
# Use `@invokelatest` to avoid world-age issues with `eval_expression = true`
1795+
@invokelatest remake(T(args...; kwargs...))
17771796
end
17781797

17791798
"""

test/basic_transformations.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -340,11 +340,12 @@ foofn(x) = 4
340340

341341
@testset "`respecialize`" begin
342342
@parameters p::AbstractFoo p2(t)::AbstractFoo = p q[1:2]::AbstractFoo r
343-
rp,
344-
rp2 = let
345-
only(@parameters p::Bar),
346-
SymbolicUtils.term(operation(p2), arguments(p2)...; type = Baz)
347-
end
343+
rp = only(let p = nothing
344+
@parameters p::Bar
345+
end)
346+
rp2 = only(let p2 = nothing
347+
@parameters p2(t)::Baz
348+
end)
348349
@variables x(t) = 1.0
349350
@named sys1 = System([D(x) ~ foofn(p) + foofn(p2) + x], t, [x], [p, p2, q, r])
350351

test/initializationsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,7 @@ end
826826
@test integ.ps[param]val rtol=1e-5
827827
# some algorithms are a little temperamental
828828
sol = solve(prob, alg)
829-
@test sol.ps[param]val rtol=1e-5
829+
@test sol.ps[param]val rtol=1e-5 broken=(alg===SimpleNewtonRaphson())
830830
@test SciMLBase.successful_retcode(sol)
831831
end
832832

test/mtkparameters.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ ps = MTKParameters(
357357
(BlockedArray([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [3, 3]),
358358
BlockedArray(falses(1), [1, 0])),
359359
(), (), ())
360-
@test SciMLBase.get_saveable_values(sys, ps, 1).x isa Tuple{Vector{Float64}, Vector{Bool}}
360+
@test SciMLBase.get_saveable_values(sys, ps, 1).x isa Tuple{Vector{Float64}, BitVector}
361361
tsidx1 = 1
362362
tsidx2 = 2
363363
@test length(ps.discrete[1][Block(tsidx1)]) == 3
@@ -368,3 +368,14 @@ with_updated_parameter_timeseries_values(
368368
sys, ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false])))
369369
@test ps.discrete[1][Block(tsidx1)] == [10.0, 11.0, 12.0]
370370
@test ps.discrete[2][Block(tsidx1)][] == false
371+
372+
@testset "Avoid specialization of nonnumeric parameters on `remake_buffer`" begin
373+
@variables x(t)
374+
@parameters p::Any
375+
@named sys = System(D(x) ~ x, t, [x], [p])
376+
sys = complete(sys)
377+
ps = MTKParameters(sys, [p => 1.0])
378+
@test ps.nonnumeric isa Tuple{Vector{Any}}
379+
ps2 = remake_buffer(sys, ps, [p], [:a])
380+
@test ps2.nonnumeric isa Tuple{Vector{Any}}
381+
end

test/precompile_test.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Test
22
using ModelingToolkit
3+
using OrdinaryDiffEqDefault
34

45
using Distributed
56

@@ -38,3 +39,5 @@ ODEPrecompileTest.f_eval_bad(u, p, 0.1)
3839
@test parentmodule(typeof(ODEPrecompileTest.f_eval_good.f.f_oop)) ==
3940
ODEPrecompileTest
4041
@test ODEPrecompileTest.f_eval_good(u, p, 0.1) == [4, 0, -16]
42+
43+
@test_nowarn solve(ODEPrecompileTest.prob_eval)

test/precompile_test/ODEPrecompileTest.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,24 @@ const f_eval_bad = system(; eval_expression = true, eval_module = @__MODULE__)
3636
# Change the module the eval'd function is eval'd into to be the containing module,
3737
# which should make it be in the package image
3838
const f_eval_good = system(; eval_expression = true, eval_module = @__MODULE__)
39+
40+
function problem(; kwargs...)
41+
# Define some variables
42+
@independent_variables t
43+
@parameters σ ρ β
44+
@variables x(t) y(t) z(t)
45+
D = Differential(t)
46+
47+
# Define a differential equation
48+
eqs = [D(x) ~ σ * (y - x),
49+
D(y) ~ x *- z) - y,
50+
D(z) ~ x * y - β * z]
51+
52+
@named de = System(eqs, t)
53+
de = complete(de)
54+
return ODEProblem(de, [x => 1, y => 0, z => 0, σ => 10, ρ => 28, β => 8/3], (0.0, 5.0); kwargs...)
55+
end
56+
57+
const prob_eval = problem(; eval_expression = true, eval_module = @__MODULE__)
58+
3959
end

0 commit comments

Comments
 (0)