Skip to content

Commit 72f7743

Browse files
Merge pull request #1941 from ValentinKaisermayer/vk-implement-process_p_u0_symbolic
Implements process_p_u0_symbolic from SciMLBase
2 parents 6e49923 + 34202b7 commit 72f7743

File tree

3 files changed

+57
-6
lines changed

3 files changed

+57
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ NonlinearSolve = "0.3.8"
7272
RecursiveArrayTools = "2.3"
7373
Reexport = "0.2, 1"
7474
RuntimeGeneratedFunctions = "0.4.3, 0.5"
75-
SciMLBase = "1.70.0"
75+
SciMLBase = "1.72.0"
7676
Setfield = "0.7, 0.8, 1"
7777
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
7878
StaticArrays = "0.10, 0.11, 0.12, 1.0"

src/variables.jl

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,42 @@ end
119119
throw(ArgumentError("$vars are missing from the variable map."))
120120
end
121121

122+
# FIXME: remove after: https://github.com/SciML/SciMLBase.jl/pull/311
123+
function SciMLBase.handle_varmap(varmap, sys::AbstractSystem; field = :states, kwargs...)
124+
out = varmap_to_vars(varmap, getfield(sys, field); kwargs...)
125+
return out
126+
end
127+
122128
"""
123129
$(SIGNATURES)
124130
125-
Intercept the call to `handle_varmap` and convert it to an ordered list if the user has
126-
ModelingToolkit loaded, and the problem has a symbolic origin.
131+
Intercept the call to `process_p_u0_symbolic` and process symbolic maps of `p` and/or `u0` if the
132+
user has `ModelingToolkit` loaded.
127133
"""
128-
function SciMLBase.handle_varmap(varmap, sys::AbstractSystem; field = :states, kwargs...)
129-
out = varmap_to_vars(varmap, getfield(sys, field); kwargs...)
130-
return out
134+
function SciMLBase.process_p_u0_symbolic(prob::ODEProblem, p, u0)
135+
# check if a symbolic remake is possible
136+
if eltype(p) <: Pair
137+
hasproperty(prob.f, :sys) && hasfield(typeof(prob.f.sys), :ps) ||
138+
throw(ArgumentError("This problem does not support symbolic maps with `remake`, i.e. it does not have a symbolic origin." *
139+
" Please use `remake` with the `p` keyword argument as a vector of values, paying attention to parameter order."))
140+
end
141+
if eltype(u0) <: Pair
142+
hasproperty(prob.f, :sys) && hasfield(typeof(prob.f.sys), :states) ||
143+
throw(ArgumentError("This problem does not support symbolic maps with `remake`, i.e. it does not have a symbolic origin." *
144+
" Please use `remake` with the `u0` keyword argument as a vector of values, paying attention to state order."))
145+
end
146+
147+
# assemble defaults
148+
defs = defaults(prob.f.sys)
149+
defs = mergedefaults(defs, prob.p, parameters(prob.f.sys))
150+
defs = mergedefaults(defs, p, parameters(prob.f.sys))
151+
defs = mergedefaults(defs, prob.u0, states(prob.f.sys))
152+
defs = mergedefaults(defs, u0, states(prob.f.sys))
153+
154+
u0 = varmap_to_vars(u0, states(prob.f.sys); defaults = defs, tofloat = true)
155+
p = varmap_to_vars(p, parameters(prob.f.sys); defaults = defs)
156+
157+
return p, u0
131158
end
132159

133160
struct IsHistory end

test/odesystem.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,30 @@ sol_dpmap = solve(prob_dpmap, Rodas5())
257257

258258
@test sol_pmap.u sol_dpmap.u
259259

260+
@testset "symbolic remake with nested system" begin
261+
function makesys(name)
262+
@parameters t a=1.0
263+
@variables x(t) = 0.0
264+
D = Differential(t)
265+
ODESystem([D(x) ~ -a * x]; name)
266+
end
267+
268+
function makecombinedsys()
269+
sys1 = makesys(:sys1)
270+
sys2 = makesys(:sys2)
271+
@parameters t b=1.0
272+
ODESystem(Equation[], t, [], [b]; systems = [sys1, sys2], name = :foo)
273+
end
274+
275+
sys = makecombinedsys()
276+
@unpack sys1, b = sys
277+
prob = ODEProblem(sys, Pair[])
278+
prob_new = SciMLBase.remake(prob, p = Dict(sys1.a => 3.0, b => 4.0),
279+
u0 = Dict(sys1.x => 1.0))
280+
@test_broken prob_new.p == [4.0, 3.0, 1.0]
281+
@test_broken prob_new.u0 == [1.0, 0.0]
282+
end
283+
260284
# test kwargs
261285
prob2 = ODEProblem(sys, u0, tspan, p, jac = true)
262286
prob3 = ODEProblem(sys, u0, tspan, p, jac = true, sparse = true) #SparseMatrixCSC need to handle

0 commit comments

Comments
 (0)