Skip to content

Commit c023e7e

Browse files
fix: undo the hack in generate_initializesystem
1 parent 6c3576f commit c023e7e

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ function generate_initializesystem(sys::ODESystem;
1212
algebraic_only = false,
1313
check_units = true, check_defguess = false,
1414
name = nameof(sys), kwargs...)
15-
vars = unique([unknowns(sys); getfield.((observed(sys)), :lhs)])
15+
trueobs = unhack_observed(observed(sys))
16+
@show trueobs
17+
vars = unique([unknowns(sys); getfield.(trueobs, :lhs)])
1618
vars_set = Set(vars) # for efficient in-lookup
1719

1820
eqs = equations(sys)
@@ -24,7 +26,7 @@ function generate_initializesystem(sys::ODESystem;
2426
D = Differential(get_iv(sys))
2527
diffmap = merge(
2628
Dict(eq.lhs => eq.rhs for eq in eqs_diff),
27-
Dict(D(eq.lhs) => D(eq.rhs) for eq in observed(sys))
29+
Dict(D(eq.lhs) => D(eq.rhs) for eq in trueobs)
2830
)
2931

3032
# 1) process dummy derivatives and u0map into initialization system
@@ -166,15 +168,14 @@ function generate_initializesystem(sys::ODESystem;
166168
)
167169

168170
# 7) use observed equations for guesses of observed variables if not provided
169-
obseqs = observed(sys)
170-
for eq in obseqs
171+
for eq in trueobs
171172
haskey(defs, eq.lhs) && continue
172173
any(x -> isequal(default_toterm(x), eq.lhs), keys(defs)) && continue
173174

174175
defs[eq.lhs] = eq.rhs
175176
end
176177

177-
eqs_ics = Symbolics.substitute.([eqs_ics; obseqs], (paramsubs,))
178+
eqs_ics = Symbolics.substitute.([eqs_ics; trueobs], (paramsubs,))
178179
vars = [vars; collect(values(paramsubs))]
179180
for k in keys(defs)
180181
defs[k] = substitute(defs[k], paramsubs)
@@ -324,3 +325,37 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
324325
return nothing, nothing, nothing, nothing
325326
end
326327
end
328+
329+
"""
330+
Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works with
331+
initialization.
332+
"""
333+
function unhack_observed(eqs::Vector{Equation})
334+
subs = Dict()
335+
tempvars = Set()
336+
rm_idxs = Int[]
337+
for (i, eq) in enumerate(eqs)
338+
iscall(eq.rhs) || continue
339+
if operation(eq.rhs) == StructuralTransformations.change_origin
340+
push!(rm_idxs, i)
341+
continue
342+
end
343+
if operation(eq.rhs) == StructuralTransformations.getindex_wrapper
344+
var, idxs = arguments(eq.rhs)
345+
subs[eq.rhs] = var[idxs...]
346+
push!(tempvars, var)
347+
end
348+
end
349+
350+
for (i, eq) in enumerate(eqs)
351+
if eq.lhs in tempvars
352+
subs[eq.lhs] = eq.rhs
353+
push!(rm_idxs, i)
354+
end
355+
end
356+
357+
eqs = eqs[setdiff(eachindex(eqs), rm_idxs)]
358+
return map(eqs) do eq
359+
fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs)
360+
end
361+
end

test/structural_transformation/utils.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ end
5353
@test any(eq -> isequal(eq.lhs, z), observed(sys))
5454
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn])
5555
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
56+
57+
isys = ModelingToolkit.generate_initializesystem(sys)
58+
@test length(unknowns(isys)) == 5
59+
@test length(equations(isys)) == 4
60+
@test !any(equations(isys)) do eq
61+
iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper,
62+
StructuralTransformations.change_origin]
63+
end
5664
end
5765

5866
@testset "scalarized array observed calling same function multiple times" begin
@@ -69,4 +77,12 @@ end
6977
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn2])
7078
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
7179
@test val[] == 1
80+
81+
isys = ModelingToolkit.generate_initializesystem(sys)
82+
@test length(unknowns(isys)) == 3
83+
@test length(equations(isys)) == 2
84+
@test !any(equations(isys)) do eq
85+
iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper,
86+
StructuralTransformations.change_origin]
87+
end
7288
end

0 commit comments

Comments
 (0)