Skip to content

Commit 93a3aaf

Browse files
authored
Remake with symbolic map (#1835)
1 parent 3d04208 commit 93a3aaf

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ module ModelingToolkit
55
using DocStringExtensions
66
using AbstractTrees
77
using DiffEqBase, SciMLBase, ForwardDiff, Reexport
8-
using SciMLBase: StandardODEProblem, StandardNonlinearProblem
8+
using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap
99
using Distributed
1010
using StaticArrays, LinearAlgebra, SparseArrays, LabelledArrays
1111
using InteractiveUtils

src/variables.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ applicable.
4646
function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
4747
toterm = Symbolics.diff2term, promotetoconcrete = nothing,
4848
tofloat = true, use_union = false)
49-
varlist = map(unwrap, varlist)
49+
varlist = collect(map(unwrap, varlist))
50+
5051
# Edge cases where one of the arguments is effectively empty.
5152
is_incomplete_initialization = varmap isa DiffEqBase.NullParameters ||
5253
varmap === nothing
@@ -97,7 +98,7 @@ function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false
9798
varmap[p] = fixpoint_sub(v, varmap)
9899
end
99100

100-
missingvars = setdiff(varlist, keys(varmap))
101+
missingvars = setdiff(varlist, collect(keys(varmap)))
101102
check && (isempty(missingvars) || throw_missingvars(missingvars))
102103

103104
out = [varmap[var] for var in varlist]
@@ -107,6 +108,17 @@ end
107108
throw(ArgumentError("$vars are missing from the variable map."))
108109
end
109110

111+
"""
112+
$(SIGNATURES)
113+
114+
Intercept the call to `handle_varmap` and convert it to an ordered list if the user has
115+
ModelingToolkit loaded, and the problem has a symbolic origin.
116+
"""
117+
function SciMLBase.handle_varmap(varmap, sys::AbstractSystem; field = :states, kwargs...)
118+
out = varmap_to_vars(varmap, getfield(sys, field); kwargs...)
119+
return out
120+
end
121+
110122
struct IsHistory end
111123
ishistory(x) = ishistory(unwrap(x))
112124
ishistory(x::Symbolic) = getmetadata(x, IsHistory, false)

test/odesystem.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,23 @@ for p in [prob1, prob14]
224224
@test Set(Num.(parameters(sys)) .=> p.p) == Set([k₁ => 0.04, k₂ => 3e7, k₃ => 1e4])
225225
@test Set(Num.(states(sys)) .=> p.u0) == Set([y₁ => 1, y₂ => 0, y₃ => 0])
226226
end
227+
# test remake with symbols
228+
p3 = [k₁ => 0.05,
229+
k₂ => 2e7,
230+
k₃ => 1.1e4]
231+
u01 = [y₁ => 1, y₂ => 1, y₃ => 1]
232+
prob_pmap = remake(prob14; p = p3, u0 = u01)
233+
prob_dpmap = remake(prob14; p = Dict(p3), u0 = Dict(u01))
234+
for p in [prob_pmap, prob_dpmap]
235+
@test Set(Num.(parameters(sys)) .=> p.p) == Set([k₁ => 0.05, k₂ => 2e7, k₃ => 1.1e4])
236+
@test Set(Num.(states(sys)) .=> p.u0) == Set([y₁ => 1, y₂ => 1, y₃ => 1])
237+
end
238+
sol_pmap = solve(prob_pmap, Rodas5())
239+
sol_dpmap = solve(prob_dpmap, Rodas5())
240+
241+
@test sol_pmap.u sol_dpmap.u
242+
243+
# test kwargs
227244
prob2 = ODEProblem(sys, u0, tspan, p, jac = true)
228245
prob3 = ODEProblem(sys, u0, tspan, p, jac = true, sparse = true)
229246
@test prob3.f.jac_prototype isa SparseMatrixCSC

0 commit comments

Comments
 (0)