Skip to content

Commit 2641fd8

Browse files
fix: handle initial values passed as Symbols
1 parent 8d4f542 commit 2641fd8

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

src/systems/problem_utils.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,42 @@ function add_toterms(varmap::AbstractDict; toterm = default_toterm)
5252
return cp
5353
end
5454

55+
"""
56+
$(TYPEDSIGNATURES)
57+
58+
Turn any `Symbol` keys in `varmap` to the appropriate symbolic variables in `sys`. Any
59+
symbols that cannot be converted are ignored.
60+
"""
61+
function symbols_to_symbolics!(sys::AbstractSystem, varmap::AbstractDict)
62+
if is_split(sys)
63+
ic = get_index_cache(sys)
64+
for k in collect(keys(varmap))
65+
k isa Symbol || continue
66+
newk = get(ic.symbol_to_variable, k, nothing)
67+
newk === nothing && continue
68+
varmap[newk] = varmap[k]
69+
delete!(varmap, k)
70+
end
71+
else
72+
syms = all_symbols(sys)
73+
for k in collect(keys(varmap))
74+
k isa Symbol || continue
75+
idx = findfirst(syms) do sym
76+
hasname(sym) || return false
77+
name = getname(sym)
78+
return name == k
79+
end
80+
idx === nothing && continue
81+
newk = syms[idx]
82+
if iscall(newk) && operation(newk) === getindex
83+
newk = arguments(newk)[1]
84+
end
85+
varmap[newk] = varmap[k]
86+
delete!(varmap, k)
87+
end
88+
end
89+
end
90+
5591
"""
5692
$(TYPEDSIGNATURES)
5793
@@ -530,8 +566,10 @@ function process_SciMLProblem(
530566
pType = typeof(pmap)
531567
_u0map = u0map
532568
u0map = to_varmap(u0map, dvs)
569+
symbols_to_symbolics!(sys, u0map)
533570
_pmap = pmap
534571
pmap = to_varmap(pmap, ps)
572+
symbols_to_symbolics!(sys, pmap)
535573
defs = add_toterms(recursive_unwrap(defaults(sys)))
536574
cmap, cs = get_cmap(sys)
537575
kwargs = NamedTuple(kwargs)

0 commit comments

Comments
 (0)