@@ -12,7 +12,9 @@ function generate_initializesystem(sys::ODESystem;
12
12
algebraic_only = false ,
13
13
check_units = true , check_defguess = false ,
14
14
name = nameof (sys), kwargs... )
15
- vars = unique ([unknowns (sys); getfield .((observed (sys)), :lhs )])
15
+ trueobs = unhack_observed (observed (sys))
16
+ @show trueobs
17
+ vars = unique ([unknowns (sys); getfield .(trueobs, :lhs )])
16
18
vars_set = Set (vars) # for efficient in-lookup
17
19
18
20
eqs = equations (sys)
@@ -24,7 +26,7 @@ function generate_initializesystem(sys::ODESystem;
24
26
D = Differential (get_iv (sys))
25
27
diffmap = merge (
26
28
Dict (eq. lhs => eq. rhs for eq in eqs_diff),
27
- Dict (D (eq. lhs) => D (eq. rhs) for eq in observed (sys) )
29
+ Dict (D (eq. lhs) => D (eq. rhs) for eq in trueobs )
28
30
)
29
31
30
32
# 1) process dummy derivatives and u0map into initialization system
@@ -166,15 +168,14 @@ function generate_initializesystem(sys::ODESystem;
166
168
)
167
169
168
170
# 7) use observed equations for guesses of observed variables if not provided
169
- obseqs = observed (sys)
170
- for eq in obseqs
171
+ for eq in trueobs
171
172
haskey (defs, eq. lhs) && continue
172
173
any (x -> isequal (default_toterm (x), eq. lhs), keys (defs)) && continue
173
174
174
175
defs[eq. lhs] = eq. rhs
175
176
end
176
177
177
- eqs_ics = Symbolics. substitute .([eqs_ics; obseqs ], (paramsubs,))
178
+ eqs_ics = Symbolics. substitute .([eqs_ics; trueobs ], (paramsubs,))
178
179
vars = [vars; collect (values (paramsubs))]
179
180
for k in keys (defs)
180
181
defs[k] = substitute (defs[k], paramsubs)
@@ -324,3 +325,37 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
324
325
return nothing , nothing , nothing , nothing
325
326
end
326
327
end
328
+
329
+ """
330
+ Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works with
331
+ initialization.
332
+ """
333
+ function unhack_observed (eqs:: Vector{Equation} )
334
+ subs = Dict ()
335
+ tempvars = Set ()
336
+ rm_idxs = Int[]
337
+ for (i, eq) in enumerate (eqs)
338
+ iscall (eq. rhs) || continue
339
+ if operation (eq. rhs) == StructuralTransformations. change_origin
340
+ push! (rm_idxs, i)
341
+ continue
342
+ end
343
+ if operation (eq. rhs) == StructuralTransformations. getindex_wrapper
344
+ var, idxs = arguments (eq. rhs)
345
+ subs[eq. rhs] = var[idxs... ]
346
+ push! (tempvars, var)
347
+ end
348
+ end
349
+
350
+ for (i, eq) in enumerate (eqs)
351
+ if eq. lhs in tempvars
352
+ subs[eq. lhs] = eq. rhs
353
+ push! (rm_idxs, i)
354
+ end
355
+ end
356
+
357
+ eqs = eqs[setdiff (eachindex (eqs), rm_idxs)]
358
+ return map (eqs) do eq
359
+ fixpoint_sub (eq. lhs, subs) ~ fixpoint_sub (eq. rhs, subs)
360
+ end
361
+ end
0 commit comments