Skip to content

Commit 5cfeeeb

Browse files
committed
Handle differentiation chains that are only partly dummy derivatives
1 parent 8274396 commit 5cfeeeb

File tree

2 files changed

+38
-14
lines changed

2 files changed

+38
-14
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,23 @@ function to_mass_matrix_form(neweqs, ieq, graph, fullvars, isdervar::F,
195195
end
196196
end
197197

198+
#=
199+
function check_diff_graph(var_to_diff, fullvars)
200+
diff_to_var = invview(var_to_diff)
201+
for (iv, v) in enumerate(fullvars)
202+
ov, order = var_from_nested_derivative(v)
203+
graph_order = 0
204+
vv = iv
205+
while true
206+
vv = diff_to_var[vv]
207+
vv === nothing && break
208+
graph_order += 1
209+
end
210+
@assert graph_order==order "graph_order: $graph_order, order: $order for variable $v"
211+
end
212+
end
213+
=#
214+
198215
function tearing_reassemble(state::TearingState, var_eq_matching; simplify = false)
199216
@unpack fullvars, sys = state
200217
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
@@ -237,8 +254,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
237254
possible_x_t[rhs] = i, lhs
238255
end
239256

240-
#removed_eqs = Int[]
241-
#removed_vars = Int[]
257+
if ModelingToolkit.has_iv(state.sys)
258+
iv = get_iv(state.sys)
259+
D = Differential(iv)
260+
else
261+
iv = D = nothing
262+
end
242263
removed_obs = Int[]
243264
diff_to_var = invview(var_to_diff)
244265
dummy_sub = Dict()
@@ -258,7 +279,19 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
258279
neweqs[eq] = substitute(neweqs[eq], dd => v_t)
259280
end
260281
fullvars[dv] = v_t
282+
# If we have:
283+
# x -> D(x) -> D(D(x))
284+
# We need to to transform it to:
285+
# x x_t -> D(x_t)
261286
# update the structural information
287+
if (ddx = var_to_diff[dv]) !== nothing
288+
dv_t = D(v_t)
289+
# TODO: handle this recursively
290+
for eq in 𝑑neighbors(graph, ddx)
291+
neweqs[eq] = substitute(neweqs[eq], fullvars[ddx] => dv_t)
292+
end
293+
fullvars[ddx] = dv_t
294+
end
262295
diff_to_var[dv] = nothing
263296
end
264297
end
@@ -323,12 +356,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
323356
# As a final note, in all the above cases where we need to introduce new
324357
# variables and equations, don't add them when they already exist.
325358

326-
if ModelingToolkit.has_iv(state.sys)
327-
iv = get_iv(state.sys)
328-
D = Differential(iv)
329-
else
330-
iv = D = nothing
331-
end
332359
nvars = ndsts(graph)
333360
processed = falses(nvars)
334361
subinfo = NTuple{3, Int}[]
@@ -488,7 +515,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
488515
# 0 ~ a * var + b
489516
# var ~ -b/a
490517
if ModelingToolkit._iszero(a)
491-
@warn "Tearing: $eq is a singular equation!"
518+
@warn "Tearing: solving $eq for $var is singular!"
492519
#push!(removed_eqs, ieq)
493520
#push!(removed_vars, iv)
494521
else

src/systems/systemstructure.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -315,17 +315,16 @@ function TearingState(sys; quick_cancel = false, check = true)
315315

316316
nvars = length(fullvars)
317317
diffvars = []
318-
vartype = fill(DIFFERENTIAL_VARIABLE, nvars)
319318
var_to_diff = DiffGraph(nvars, true)
320319
for dervaridx in dervaridxs
321-
vartype[dervaridx] = DERIVATIVE_VARIABLE
322320
dervar = fullvars[dervaridx]
323321
diffvar = arguments(dervar)[1]
324322
diffvaridx = var2idx[diffvar]
325323
push!(diffvars, diffvar)
326324
var_to_diff[diffvaridx] = dervaridx
327325
end
328326

327+
#=
329328
algvars = setdiff(states(sys), diffvars)
330329
for algvar in algvars
331330
# it could be that a variable appeared in the states, but never appeared
@@ -334,10 +333,8 @@ function TearingState(sys; quick_cancel = false, check = true)
334333
#if algvaridx == 0
335334
# check ? throw(InvalidSystemException("The system is missing an equation for $algvar.")) : return nothing
336335
#end
337-
if algvaridx != 0
338-
vartype[algvaridx] = ALGEBRAIC_VARIABLE
339-
end
340336
end
337+
=#
341338

342339
graph = BipartiteGraph(neqs, nvars, Val(false))
343340
for (ie, vars) in enumerate(symbolic_incidence), v in vars

0 commit comments

Comments
 (0)