Skip to content

Commit ba519bf

Browse files
Merge pull request #3146 from AayushSabharwal/as/initprob-dependent-unknowns
fix: construct `initializeprob` if initial value is symbolic
2 parents f8f8142 + 2184598 commit ba519bf

File tree

3 files changed

+59
-4
lines changed

3 files changed

+59
-4
lines changed

src/systems/problem_utils.jl

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,24 @@ function to_varmap(vals, varlist::Vector)
171171
check_eqs_u0(varlist, varlist, vals)
172172
vals = vec(varlist) .=> vec(vals)
173173
end
174-
return anydict(unwrap(k) => unwrap(v) for (k, v) in anydict(vals))
174+
return recursive_unwrap(anydict(vals))
175+
end
176+
177+
"""
178+
$(TYPEDSIGNATURES)
179+
180+
Recursively call `Symbolics.unwrap` on `x`. Useful when `x` is an array of (potentially)
181+
symbolic values, all of which need to be unwrapped. Specializes when `x isa AbstractDict`
182+
to unwrap keys and values, returning an `AnyDict`.
183+
"""
184+
function recursive_unwrap(x::AbstractArray)
185+
symbolic_type(x) == ArraySymbolic() ? unwrap(x) : recursive_unwrap.(x)
186+
end
187+
188+
recursive_unwrap(x) = unwrap(x)
189+
190+
function recursive_unwrap(x::AbstractDict)
191+
return anydict(unwrap(k) => recursive_unwrap(v) for (k, v) in x)
175192
end
176193

177194
"""
@@ -262,7 +279,7 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
262279
end
263280
vals = map(x -> varmap[x], vars)
264281

265-
if container_type <: Union{AbstractDict, Tuple, Nothing}
282+
if container_type <: Union{AbstractDict, Tuple, Nothing, SciMLBase.NullParameters}
266283
container_type = Array
267284
end
268285

@@ -410,7 +427,7 @@ function process_SciMLProblem(
410427
u0map = to_varmap(u0map, dvs)
411428
_pmap = pmap
412429
pmap = to_varmap(pmap, ps)
413-
defs = add_toterms(defaults(sys))
430+
defs = add_toterms(recursive_unwrap(defaults(sys)))
414431
cmap, cs = get_cmap(sys)
415432
kwargs = NamedTuple(kwargs)
416433

@@ -433,9 +450,14 @@ function process_SciMLProblem(
433450
solvablepars = [p
434451
for p in parameters(sys)
435452
if is_parameter_solvable(p, pmap, defs, guesses)]
453+
has_dependent_unknowns = any(unknowns(sys)) do sym
454+
val = get(op, sym, nothing)
455+
val === nothing && return false
456+
return symbolic_type(val) != NotSymbolic() || is_array_of_symbolics(val)
457+
end
436458
if build_initializeprob &&
437459
(((implicit_dae || has_observed_u0s || !isempty(missing_unknowns) ||
438-
!isempty(solvablepars)) &&
460+
!isempty(solvablepars) || has_dependent_unknowns) &&
439461
get_tearing_state(sys) !== nothing) ||
440462
!isempty(initialization_equations(sys))) && t !== nothing
441463
initializeprob = ModelingToolkit.InitializationProblem(

test/initial_values.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,20 @@ end
119119
prob = ODEProblem(sys, [], (1.0, 2.0), [])
120120
@test prob[x] == 1.0
121121
@test prob.ps[p] == 2.0
122+
123+
@testset "Array of symbolics is unwrapped" begin
124+
@variables x(t)[1:2] y(t)
125+
@mtkbuild sys = ODESystem([D(x) ~ x, D(y) ~ t], t; defaults = [x => [y, 3.0]])
126+
prob = ODEProblem(sys, [y => 1.0], (0.0, 1.0))
127+
@test eltype(prob.u0) <: Float64
128+
prob = ODEProblem(sys, [x => [y, 4.0], y => 2.0], (0.0, 1.0))
129+
@test eltype(prob.u0) <: Float64
130+
end
131+
132+
@testset "split=false systems with all parameter defaults" begin
133+
@variables x(t) = 1.0
134+
@parameters p=1.0 q=2.0 r=3.0
135+
@mtkbuild sys=ODESystem(D(x) ~ p * x + q * t + r, t) split=false
136+
prob = @test_nowarn ODEProblem(sys, [], (0.0, 1.0))
137+
@test prob.p isa Vector{Float64}
138+
end

test/initializationsystem.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,3 +844,19 @@ end
844844
isys = ModelingToolkit.generate_initializesystem(sys)
845845
@test isequal(defaults(isys)[y], 2x + 1)
846846
end
847+
848+
@testset "Create initializeprob when unknown has dependent value" begin
849+
@variables x(t) y(t)
850+
@mtkbuild sys = ODESystem([D(x) ~ x, D(y) ~ t * y], t; defaults = [x => 2y])
851+
prob = ODEProblem(sys, [y => 1.0], (0.0, 1.0))
852+
@test prob.f.initializeprob !== nothing
853+
integ = init(prob)
854+
@test integ[x] 2.0
855+
856+
@variables x(t)[1:2] y(t)
857+
@mtkbuild sys = ODESystem([D(x) ~ x, D(y) ~ t], t; defaults = [x => [y, 3.0]])
858+
prob = ODEProblem(sys, [y => 1.0], (0.0, 1.0))
859+
@test prob.f.initializeprob !== nothing
860+
integ = init(prob)
861+
@test integ[x] [1.0, 3.0]
862+
end

0 commit comments

Comments
 (0)