Skip to content

Commit 6081a50

Browse files
fix: recursively unwrap arrays of symbolics in process_SciMLProblem
1 parent 21d1ce7 commit 6081a50

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

src/systems/problem_utils.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,24 @@ function to_varmap(vals, varlist::Vector)
171171
check_eqs_u0(varlist, varlist, vals)
172172
vals = vec(varlist) .=> vec(vals)
173173
end
174-
return anydict(unwrap(k) => unwrap(v) for (k, v) in anydict(vals))
174+
return recursive_unwrap(anydict(vals))
175+
end
176+
177+
"""
178+
$(TYPEDSIGNATURES)
179+
180+
Recursively call `Symbolics.unwrap` on `x`. Useful when `x` is an array of (potentially)
181+
symbolic values, all of which need to be unwrapped. Specializes when `x isa AbstractDict`
182+
to unwrap keys and values, returning an `AnyDict`.
183+
"""
184+
function recursive_unwrap(x::AbstractArray)
185+
symbolic_type(x) == ArraySymbolic() ? unwrap(x) : recursive_unwrap.(x)
186+
end
187+
188+
recursive_unwrap(x) = unwrap(x)
189+
190+
function recursive_unwrap(x::AbstractDict)
191+
return anydict(unwrap(k) => recursive_unwrap(v) for (k, v) in x)
175192
end
176193

177194
"""
@@ -410,7 +427,7 @@ function process_SciMLProblem(
410427
u0map = to_varmap(u0map, dvs)
411428
_pmap = pmap
412429
pmap = to_varmap(pmap, ps)
413-
defs = add_toterms(defaults(sys))
430+
defs = add_toterms(recursive_unwrap(defaults(sys)))
414431
cmap, cs = get_cmap(sys)
415432
kwargs = NamedTuple(kwargs)
416433

test/initial_values.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,12 @@ end
119119
prob = ODEProblem(sys, [], (1.0, 2.0), [])
120120
@test prob[x] == 1.0
121121
@test prob.ps[p] == 2.0
122+
123+
@testset "Array of symbolics is unwrapped" begin
124+
@variables x(t)[1:2] y(t)
125+
@mtkbuild sys = ODESystem([D(x) ~ x, D(y) ~ t], t; defaults = [x => [y, 3.0]])
126+
prob = ODEProblem(sys, [y => 1.0], (0.0, 1.0))
127+
@test eltype(prob.u0) <: Float64
128+
prob = ODEProblem(sys, [x => [y, 4.0], y => 2.0], (0.0, 1.0))
129+
@test eltype(prob.u0) <: Float64
130+
end

0 commit comments

Comments
 (0)