Skip to content

Commit b44667c

Browse files
committed
Fix misuses of isdiffvar
1 parent 270c0ff commit b44667c

File tree

4 files changed

+68
-66
lines changed

4 files changed

+68
-66
lines changed

src/structural_transformation/codegen.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ const MAX_INLINE_NLSOLVE_SIZE = 8
77

88
function torn_system_with_nlsolve_jacobian_sparsity(state, var_eq_matching, var_sccs,
99
nlsolve_scc_idxs, eqs_idxs, states_idxs)
10-
fullvars = state.fullvars
1110
graph = state.structure.graph
1211

1312
# The sparsity pattern of `nlsolve(f, u, p)` w.r.t `p` is difficult to
@@ -72,7 +71,6 @@ function torn_system_with_nlsolve_jacobian_sparsity(state, var_eq_matching, var_
7271

7372
var2idx = Dict{Int, Int}(v => i for (i, v) in enumerate(states_idxs))
7473
eqs2idx = Dict{Int, Int}(v => i for (i, v) in enumerate(eqs_idxs))
75-
nlsolve_vars_set = BitSet(nlsolve_vars)
7674

7775
I = Int[]
7876
J = Int[]

src/structural_transformation/symbolics_tearing.jl

Lines changed: 62 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,43 @@ function substitute_vars!(graph::BipartiteGraph, subs, cache = Int[], callback!
134134
graph
135135
end
136136

137+
function to_mass_matrix_form(neweqs, ieq, graph, fullvars, isdervar::F) where F
138+
eq = neweqs[ieq]
139+
if !(eq.lhs isa Number && eq.lhs == 0)
140+
eq = 0 ~ eq.rhs - eq.lhs
141+
end
142+
rhs = eq.rhs
143+
if rhs isa Symbolic
144+
# Check if the RHS is solvable in all state derivatives and if those
145+
# the linear terms for them are all zero. If so, move them to the
146+
# LHS.
147+
dervar::Union{Nothing, Int} = nothing
148+
for var in 𝑠neighbors(graph, ieq)
149+
if isdervar(var)
150+
if dervar !== nothing
151+
error("$eq has more than one differentiated variable!")
152+
end
153+
dervar = var
154+
end
155+
end
156+
dervar === nothing && return 0 ~ rhs
157+
new_lhs = var = fullvars[dervar]
158+
# 0 ~ a * D(x) + b
159+
# D(x) ~ -b/a
160+
a, b, islinear = linear_expansion(rhs, var)
161+
if !islinear
162+
return 0 ~ rhs
163+
end
164+
new_rhs = -b / a
165+
return new_lhs ~ new_rhs
166+
else # a number
167+
if abs(rhs) > 100eps(float(rhs))
168+
@warn "The equation $eq is not consistent. It simplifed to 0 == $rhs."
169+
end
170+
return nothing
171+
end
172+
end
173+
137174
function tearing_reassemble(state::TearingState, var_eq_matching; simplify = false)
138175
@unpack fullvars, sys = state
139176
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
@@ -205,13 +242,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
205242
# variables that appear differentiated are differential variables.
206243

207244
### extract partition information
208-
is_solvable(eq, iv) = isa(eq, Int) && BipartiteEdge(eq, iv) in solvable_graph
209-
210-
solved_equations = Int[]
211-
solved_variables = Int[]
245+
is_solvable = let solvable_graph = solvable_graph
246+
(eq, iv) -> eq isa Int && iv isa Int && BipartiteEdge(eq, iv) in solvable_graph
247+
end
212248

213249
# if var is like D(x)
214-
isdiffvar = let diff_to_var = diff_to_var
250+
isdervar = let diff_to_var = diff_to_var
215251
var -> diff_to_var[var] !== nothing
216252
end
217253

@@ -387,57 +423,35 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
387423
empty!(subs)
388424
end
389425

390-
# Rewrite remaining equations in terms of solved variables
391-
function to_mass_matrix_form(ieq)
392-
eq = neweqs[ieq]
393-
if !(eq.lhs isa Number && eq.lhs == 0)
394-
eq = 0 ~ eq.rhs - eq.lhs
395-
end
396-
rhs = eq.rhs
397-
if rhs isa Symbolic
398-
# Check if the RHS is solvable in all state derivatives and if those
399-
# the linear terms for them are all zero. If so, move them to the
400-
# LHS.
401-
dterms = [var for var in 𝑠neighbors(graph, ieq) if isdiffvar(var)]
402-
length(dterms) == 0 && return 0 ~ rhs
403-
new_rhs = rhs
404-
new_lhs = 0
405-
for iv in dterms
406-
var = fullvars[iv]
407-
# 0 ~ a * D(x) + b
408-
# D(x) ~ -b/a
409-
a, b, islinear = linear_expansion(new_rhs, var)
410-
if !islinear
411-
return 0 ~ rhs
412-
end
413-
new_lhs += var
414-
new_rhs = -b / a
415-
end
416-
return new_lhs ~ new_rhs
417-
else # a number
418-
if abs(rhs) > 100eps(float(rhs))
419-
@warn "The equation $eq is not consistent. It simplifed to 0 == $rhs."
420-
end
421-
return nothing
422-
end
423-
end
424-
425426
diffeq_idxs = BitSet()
426427
final_eqs = Equation[]
427428
var_rename = zeros(Int, length(var_eq_matching))
428429
subeqs = Equation[]
430+
solved_equations = Int[]
431+
solved_variables = Int[]
429432
idx = 0
430433
# Solve solvable equations
431434
for (iv, ieq) in enumerate(var_eq_matching)
435+
if is_solvable(ieq, iv)
436+
if isdervar(iv)
437+
var_rename[iv] = (idx += 1)
438+
end
439+
var_rename[iv] = -1
440+
else
441+
var_rename[iv] = (idx += 1)
442+
end
443+
end
444+
neqs = nsrcs(graph)
445+
for (ieq, iv) in enumerate(invview(var_eq_matching))
446+
ieq > neqs && break
432447
if is_solvable(ieq, iv)
433448
# We don't solve differential equations, but we will need to try to
434449
# convert it into the mass matrix form.
435450
# We cannot solve the differential variable like D(x)
436-
if isdiffvar(iv)
451+
if isdervar(iv)
437452
# TODO: what if `to_mass_matrix_form(ieq)` returns `nothing`?
438-
push!(final_eqs, to_mass_matrix_form(ieq))
453+
push!(final_eqs, to_mass_matrix_form(neweqs, ieq, graph, fullvars, isdervar))
439454
push!(diffeq_idxs, ieq)
440-
var_rename[iv] = (idx += 1)
441455
continue
442456
end
443457
eq = neweqs[ieq]
@@ -456,11 +470,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
456470
push!(solved_equations, ieq)
457471
push!(solved_variables, iv)
458472
end
459-
var_rename[iv] = -1
460473
else
461-
var_rename[iv] = (idx += 1)
474+
push!(final_eqs, to_mass_matrix_form(neweqs, ieq, graph, fullvars, isdervar))
462475
end
463476
end
477+
# TODO: BLT sorting
478+
neweqs = final_eqs
464479

465480
if isempty(solved_equations)
466481
deps = Vector{Int}[]
@@ -476,16 +491,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
476491
for j in toporder]
477492
end
478493

479-
# TODO: BLT sorting
480-
# Rewrite remaining equations in terms of solved variables
481-
solved_eq_set = BitSet(solved_equations)
482-
for ieq in 1:length(neweqs)
483-
(ieq in diffeq_idxs || ieq in solved_eq_set) && continue
484-
maybe_eq = to_mass_matrix_form(ieq)
485-
maybe_eq === nothing || push!(final_eqs, maybe_eq)
486-
end
487-
neweqs = final_eqs
488-
489494
# Contract the vertices in the structure graph to make the structure match
490495
# the new reality of the system we've just created.
491496
#
@@ -494,7 +499,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
494499

495500
# Update system
496501
solved_variables_set = BitSet(solved_variables)
497-
active_vars = setdiff!(setdiff(BitSet(1:length(fullvars)), solved_variables_set),
502+
active_vars = setdiff!(setdiff!(BitSet(1:length(fullvars)), solved_variables_set),
498503
removed_vars)
499504
new_var_to_diff = complete(DiffGraph(length(active_vars)))
500505
idx = 0
@@ -525,7 +530,7 @@ end
525530
function tearing(state::TearingState; kwargs...)
526531
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
527532
complete!(state.structure)
528-
@unpack graph, solvable_graph = state.structure
533+
@unpack graph = state.structure
529534
algvars = BitSet(findall(v -> isalgvar(state.structure, v), 1:ndsts(graph)))
530535
aeqs = algeqs(state.structure)
531536
var_eq_matching′ = tear_graph_modia(state.structure;

src/systems/alias_elimination.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -450,8 +450,8 @@ count_nonzeros(a::AbstractArray) = count(!iszero, a)
450450
# Here we have a guarantee that they won't, so we can make this identification
451451
count_nonzeros(a::SparseVector) = nnz(a)
452452

453-
function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL, only_algebraic,
454-
irreducibles)
453+
function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL, only_algebraic = true,
454+
irreducibles = ())
455455
mm = copy(mm_orig)
456456
is_linear_equations = falses(size(AsSubMatrix(mm_orig), 1))
457457

src/systems/systemstructure.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,15 +171,14 @@ Base.@kwdef mutable struct SystemStructure
171171
graph::BipartiteGraph{Int, Nothing}
172172
solvable_graph::Union{BipartiteGraph{Int, Nothing}, Nothing}
173173
end
174-
function isdervar(s::SystemStructure, i)
175-
s.var_to_diff[i] === nothing &&
176-
invview(s.var_to_diff)[i] !== nothing
177-
end
174+
isdervar(s::SystemStructure, i) = invview(s.var_to_diff)[i] !== nothing
178175
function isalgvar(s::SystemStructure, i)
179176
s.var_to_diff[i] === nothing &&
180177
invview(s.var_to_diff)[i] === nothing
181178
end
182-
isdiffvar(s::SystemStructure, i) = s.var_to_diff[i] !== nothing
179+
function isdiffvar(s::SystemStructure, i)
180+
s.var_to_diff[i] !== nothing && invview(s.var_to_diff)[i] === nothing
181+
end
183182

184183
function dervars_range(s::SystemStructure)
185184
Iterators.filter(Base.Fix1(isdervar, s), Base.OneTo(ndsts(s.graph)))

0 commit comments

Comments
 (0)