Skip to content

Commit a24a764

Browse files
committed
Check if x_t is in the observed variables before creating
1 parent 2d8d79a commit a24a764

File tree

2 files changed

+51
-8
lines changed

2 files changed

+51
-8
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ function substitute_vars!(graph::BipartiteGraph, subs, cache = Int[], callback!
135135
end
136136

137137
function tearing_reassemble(state::TearingState, var_eq_matching; simplify = false)
138-
fullvars = state.fullvars
138+
@unpack fullvars, sys = state
139139
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
140140

141141
neweqs = collect(equations(state))
@@ -165,16 +165,32 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
165165
# Step 1:
166166
# Replace derivatives of non-selected states by dummy derivatives
167167

168+
possible_x_t = Dict()
169+
oldobs = observed(sys)
170+
for (i, eq) in enumerate(oldobs)
171+
lhs = eq.lhs
172+
rhs = eq.rhs
173+
isdifferential(lhs) && continue
174+
# TODO: should we hanlde negative alias as well?
175+
isdifferential(rhs) || continue
176+
possible_x_t[rhs] = i, lhs
177+
end
178+
168179
removed_eqs = Int[]
169180
removed_vars = Int[]
181+
removed_obs = Int[]
170182
diff_to_var = invview(var_to_diff)
171183
for var in 1:length(fullvars)
172184
dv = var_to_diff[var]
173185
dv === nothing && continue
174186
if var_eq_matching[var] !== SelectedState()
175187
dd = fullvars[dv]
176-
# TODO: check if observed has it
177-
v_t = diff2term(unwrap(dd))
188+
if (i_v_t = get(possible_x_t, rhs, nothing)) === nothing
189+
v_t = diff2term(unwrap(dd))
190+
else
191+
idx, v_t = i_v_t
192+
push!(removed_obs, idx)
193+
end
178194
for eq in 𝑑neighbors(graph, dv)
179195
neweqs[eq] = substitute(neweqs[eq], fullvars[dv] => v_t)
180196
end
@@ -312,8 +328,14 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
312328
dx = fullvars[dx_idx]
313329
end
314330

315-
# TODO: check if it's already in observed
316-
x_t = ModelingToolkit.lower_varname(ogx, iv, o)
331+
if (i_x_t = get(possible_x_t, dx, nothing)) === nothing &&
332+
(ogidx !== nothing &&
333+
(i_x_t = get(possible_x_t, fullvars[ogidx], nothing)) === nothing)
334+
x_t = ModelingToolkit.lower_varname(ogx, iv, o)
335+
else
336+
idx, x_t = i_x_t
337+
push!(removed_obs, idx)
338+
end
317339
push!(fullvars, x_t)
318340
x_t_idx = add_vertex!(var_to_diff)
319341
add_vertex!(graph, DST)
@@ -491,7 +513,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
491513
sys = state.sys
492514
@set! sys.eqs = neweqs
493515
@set! sys.states = [fullvars[i] for i in active_vars if diff_to_var[i] === nothing]
494-
@set! sys.observed = [observed(sys); subeqs]
516+
deleteat!(oldobs, sort!(removed_obs))
517+
@set! sys.observed = [oldobs; subeqs]
495518
@set! sys.substitutions = Substitutions(subeqs, deps)
496519
@set! state.sys = sys
497520
@set! sys.tearing_state = state

src/systems/diffeqs/odesystem.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ function build_explicit_observed_function(sys, ts;
282282
# the expression depends on everything before the `maxidx`.
283283
subs = Dict()
284284
maxidx = 0
285-
for (i, s) in enumerate(dep_vars)
285+
for s in dep_vars
286286
idx = get(observed_idx, s, nothing)
287287
if idx !== nothing
288288
idx > maxidx && (maxidx = idx)
@@ -307,7 +307,27 @@ function build_explicit_observed_function(sys, ts;
307307
end
308308
end
309309
ts = map(t -> substitute(t, subs), ts)
310-
obsexprs = map(eq -> eq.lhs eq.rhs, obs[1:maxidx])
310+
obsexprs = []
311+
eqs_cache = Ref{Any}(nothing)
312+
for i in 1:maxidx
313+
eq = obs[i]
314+
lhs = eq.lhs
315+
rhs = eq.rhs
316+
vars!(vars, rhs)
317+
for v in vars
318+
isdifferential(v) || continue
319+
if eqs_cache[] === nothing
320+
eqs_cache[] = Dict(eq.lhs => eq.rhs for eq in equations(sys))
321+
end
322+
eqs_dict = eqs_cache[]
323+
rhs = get(eqs_dict, v, nothing)
324+
if rhs === nothing
325+
error("Observed variables depends on differentiated variable $v, but it's not explicit solved. Fix file an issue if you are sure that the system is valid.")
326+
end
327+
end
328+
empty!(vars)
329+
push!(obsexprs, lhs rhs)
330+
end
311331

312332
dvs = DestructuredArgs(states(sys), inbounds = !checkbounds)
313333
ps = DestructuredArgs(parameters(sys), inbounds = !checkbounds)

0 commit comments

Comments
 (0)