Skip to content

Commit 0325d7b

Browse files
authored
Merge pull request #1855 from SciML/myb/lowering_fix
Handle differentiation chains that are only partly dummy derivatives
2 parents 10edb8b + 64ed5e4 commit 0325d7b

File tree

5 files changed

+53
-24
lines changed

5 files changed

+53
-24
lines changed

src/structural_transformation/pantelides.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ end
7575
Perform Pantelides algorithm.
7676
"""
7777
function pantelides!(state::TransformationState, ag::Union{AliasGraph, Nothing} = nothing;
78-
maxiters = 8000)
78+
finalize = true, maxiters = 8000)
7979
@unpack graph, solvable_graph, var_to_diff, eq_to_diff = state.structure
8080
neqs = nsrcs(graph)
8181
nvars = nv(var_to_diff)
@@ -163,6 +163,11 @@ function pantelides!(state::TransformationState, ag::Union{AliasGraph, Nothing}
163163
pathfound ||
164164
error("maxiters=$maxiters reached! File a bug report if your system has a reasonable index (<100), and you are using the default `maxiters`. Try to increase the maxiters by `pantelides(sys::ODESystem; maxiters=1_000_000)` if your system has an incredibly high index and it is truly extremely large.")
165165
end # for k in 1:neqs′
166+
167+
finalize && for var in 1:ndsts(graph)
168+
varwhitelist[var] && continue
169+
var_eq_matching[var] = unassigned
170+
end
166171
return var_eq_matching
167172
end
168173

@@ -175,6 +180,6 @@ instead, which calls this function internally.
175180
"""
176181
function dae_index_lowering(sys::ODESystem; kwargs...)
177182
state = TearingState(sys)
178-
var_eq_matching = pantelides!(state; kwargs...)
183+
var_eq_matching = pantelides!(state; finalize = false, kwargs...)
179184
return invalidate_cache!(pantelides_reassemble(state, var_eq_matching))
180185
end

src/structural_transformation/partial_state_selection.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,6 @@ function partial_state_selection_graph!(structure::SystemStructure, var_eq_match
143143
level
144144
end
145145

146-
# TODO: Should pantelides just return this?
147-
for var in 1:ndsts(graph)
148-
if varlevel[var] !== 0
149-
var_eq_matching[var] = unassigned
150-
end
151-
end
152-
153146
var_eq_matching = pss_graph_modia!(structure,
154147
complete(var_eq_matching), varlevel, inv_varlevel,
155148
inv_eqlevel)
@@ -267,6 +260,10 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
267260
end
268261
end
269262

263+
if (n_diff_eqs = count(!isnothing, diff_to_eq)) !=
264+
(n_dummys = length(dummy_derivatives))
265+
@warn "The number of dummy derivatives ($n_dummys) does not match the number of differentiated equations ($n_diff_eqs)."
266+
end
270267
dummy_derivatives_set = BitSet(dummy_derivatives)
271268
# We can eliminate variables that are not a selected state (differential
272269
# variables). Selected states are differentiated variables that are not

src/structural_transformation/symbolics_tearing.jl

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,23 @@ function to_mass_matrix_form(neweqs, ieq, graph, fullvars, isdervar::F,
195195
end
196196
end
197197

198+
#=
199+
function check_diff_graph(var_to_diff, fullvars)
200+
diff_to_var = invview(var_to_diff)
201+
for (iv, v) in enumerate(fullvars)
202+
ov, order = var_from_nested_derivative(v)
203+
graph_order = 0
204+
vv = iv
205+
while true
206+
vv = diff_to_var[vv]
207+
vv === nothing && break
208+
graph_order += 1
209+
end
210+
@assert graph_order==order "graph_order: $graph_order, order: $order for variable $v"
211+
end
212+
end
213+
=#
214+
198215
function tearing_reassemble(state::TearingState, var_eq_matching; simplify = false)
199216
@unpack fullvars, sys = state
200217
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
@@ -237,8 +254,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
237254
possible_x_t[rhs] = i, lhs
238255
end
239256

240-
#removed_eqs = Int[]
241-
#removed_vars = Int[]
257+
if ModelingToolkit.has_iv(state.sys)
258+
iv = get_iv(state.sys)
259+
D = Differential(iv)
260+
else
261+
iv = D = nothing
262+
end
242263
removed_obs = Int[]
243264
diff_to_var = invview(var_to_diff)
244265
dummy_sub = Dict()
@@ -258,7 +279,22 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
258279
neweqs[eq] = substitute(neweqs[eq], dd => v_t)
259280
end
260281
fullvars[dv] = v_t
282+
# If we have:
283+
# x -> D(x) -> D(D(x))
284+
# We need to to transform it to:
285+
# x x_t -> D(x_t)
261286
# update the structural information
287+
dx = dv
288+
x_t = v_t
289+
while (ddx = var_to_diff[dx]) !== nothing
290+
dx_t = D(x_t)
291+
for eq in 𝑑neighbors(graph, ddx)
292+
neweqs[eq] = substitute(neweqs[eq], fullvars[ddx] => dx_t)
293+
end
294+
fullvars[ddx] = dx_t
295+
dx = ddx
296+
x_t = dx_t
297+
end
262298
diff_to_var[dv] = nothing
263299
end
264300
end
@@ -323,12 +359,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
323359
# As a final note, in all the above cases where we need to introduce new
324360
# variables and equations, don't add them when they already exist.
325361

326-
if ModelingToolkit.has_iv(state.sys)
327-
iv = get_iv(state.sys)
328-
D = Differential(iv)
329-
else
330-
iv = D = nothing
331-
end
332362
nvars = ndsts(graph)
333363
processed = falses(nvars)
334364
subinfo = NTuple{3, Int}[]
@@ -488,7 +518,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
488518
# 0 ~ a * var + b
489519
# var ~ -b/a
490520
if ModelingToolkit._iszero(a)
491-
@warn "Tearing: $eq is a singular equation!"
521+
@warn "Tearing: solving $eq for $var is singular!"
492522
#push!(removed_eqs, ieq)
493523
#push!(removed_vars, iv)
494524
else

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ function generate_difference_cb(sys::ODESystem, dvs = states(sys), ps = paramete
189189
end
190190

191191
function calculate_massmatrix(sys::AbstractODESystem; simplify = false)
192-
eqs = [eq for eq in full_equations(sys) if !isdifferenceeq(eq)]
192+
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
193193
dvs = states(sys)
194194
M = zeros(length(eqs), length(eqs))
195195
state2idx = Dict(s => i for (i, s) in enumerate(dvs))

src/systems/systemstructure.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -315,17 +315,16 @@ function TearingState(sys; quick_cancel = false, check = true)
315315

316316
nvars = length(fullvars)
317317
diffvars = []
318-
vartype = fill(DIFFERENTIAL_VARIABLE, nvars)
319318
var_to_diff = DiffGraph(nvars, true)
320319
for dervaridx in dervaridxs
321-
vartype[dervaridx] = DERIVATIVE_VARIABLE
322320
dervar = fullvars[dervaridx]
323321
diffvar = arguments(dervar)[1]
324322
diffvaridx = var2idx[diffvar]
325323
push!(diffvars, diffvar)
326324
var_to_diff[diffvaridx] = dervaridx
327325
end
328326

327+
#=
329328
algvars = setdiff(states(sys), diffvars)
330329
for algvar in algvars
331330
# it could be that a variable appeared in the states, but never appeared
@@ -334,10 +333,8 @@ function TearingState(sys; quick_cancel = false, check = true)
334333
#if algvaridx == 0
335334
# check ? throw(InvalidSystemException("The system is missing an equation for $algvar.")) : return nothing
336335
#end
337-
if algvaridx != 0
338-
vartype[algvaridx] = ALGEBRAIC_VARIABLE
339-
end
340336
end
337+
=#
341338

342339
graph = BipartiteGraph(neqs, nvars, Val(false))
343340
for (ie, vars) in enumerate(symbolic_incidence), v in vars

0 commit comments

Comments
 (0)