Skip to content

Commit f0ec298

Browse files
committed
Add crude irreducible handling
1 parent 8c16469 commit f0ec298

File tree

4 files changed

+121
-99
lines changed

4 files changed

+121
-99
lines changed

src/systems/abstractsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,7 @@ function linearization_function(sys::AbstractSystem, inputs,
10241024
input_idxs = input_idxs,
10251025
sts = states(sys),
10261026
fun = ODEFunction(sys),
1027-
h = ModelingToolkit.build_explicit_observed_function(sys, outputs),
1027+
h = build_explicit_observed_function(sys, outputs),
10281028
chunk = ForwardDiff.Chunk(input_idxs)
10291029

10301030
function (u, p, t)

src/systems/alias_elimination.jl

Lines changed: 105 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -54,23 +54,16 @@ function alias_elimination!(state::TearingState)
5454
end
5555

5656
subs = Dict()
57+
obs = Equation[]
5758
# If we encounter y = -D(x), then we need to expand the derivative when
5859
# D(y) appears in the equation, so that D(-D(x)) becomes -D(D(x)).
5960
to_expand = Int[]
6061
diff_to_var = invview(var_to_diff)
61-
# TODO/FIXME: this also needs to be computed recursively because we need to
62-
# follow the alias graph like `a => b => c` and make sure that the final
63-
# graph always contains the destination.
64-
extra_eqs = Equation[]
65-
extra_vars = BitSet()
6662
for (v, (coeff, alias)) in pairs(ag)
67-
if iszero(alias) || !(isempty(𝑑neighbors(graph, alias)) && isempty(𝑑neighbors(graph, v)))
68-
subs[fullvars[v]] = iszero(coeff) ? 0 : coeff * fullvars[alias]
69-
else
70-
push!(extra_eqs, 0 ~ coeff * fullvars[alias] - fullvars[v])
71-
push!(extra_vars, v)
72-
#@show fullvars[v] => fullvars[alias]
73-
end
63+
lhs = fullvars[v]
64+
rhs = iszero(coeff) ? 0 : coeff * fullvars[alias]
65+
subs[lhs] = rhs
66+
v != alias && push!(obs, lhs ~ rhs)
7467
if coeff == -1
7568
# if `alias` is like -D(x)
7669
diff_to_var[alias] === nothing && continue
@@ -101,7 +94,7 @@ function alias_elimination!(state::TearingState)
10194
idx = 0
10295
cursor = 1
10396
ndels = length(dels)
104-
for (i, e) in enumerate(old_to_new)
97+
for i in eachindex(old_to_new)
10598
if cursor <= ndels && i == dels[cursor]
10699
cursor += 1
107100
old_to_new[i] = -1
@@ -111,7 +104,9 @@ function alias_elimination!(state::TearingState)
111104
old_to_new[i] = idx
112105
end
113106

107+
lineqs = BitSet(old_to_new[e] for e in mm.nzrows)
114108
for (ieq, eq) in enumerate(eqs)
109+
ieq in lineqs && continue
115110
eqs[ieq] = substitute(eq, subs)
116111
end
117112

@@ -131,7 +126,7 @@ function alias_elimination!(state::TearingState)
131126
sys = state.sys
132127
@set! sys.eqs = eqs
133128
@set! sys.states = newstates
134-
@set! sys.observed = [observed(sys); [lhs ~ rhs for (lhs, rhs) in pairs(subs)]]
129+
@set! sys.observed = [observed(sys); obs]
135130
return invalidate_cache!(sys)
136131
end
137132

@@ -212,7 +207,8 @@ function Base.getindex(ag::AliasGraph, i::Integer)
212207
coeff, var = (sign(r), abs(r))
213208
nc = coeff
214209
av = var
215-
if var in keys(ag)
210+
# We support `x -> -x` as an alias.
211+
if var != i && var in keys(ag)
216212
# Amortized lookup. Check if since we last looked this up, our alias was
217213
# itself aliased. If so, just adjust the alias table.
218214
ac, av = ag[var]
@@ -248,6 +244,10 @@ end
248244

249245
function Base.setindex!(ag::AliasGraph, p::Pair{Int, Int}, i::Integer)
250246
(c, v) = p
247+
if c == 0 || v == 0
248+
ag[i] = 0
249+
return p
250+
end
251251
@assert v != 0 && c in (-1, 1)
252252
if ag.aliasto[i] === nothing
253253
push!(ag.eliminated, i)
@@ -280,22 +280,27 @@ function reduce!(mm::SparseMatrixCLIL, ag::AliasGraph)
280280
c = rs[j]
281281
_alias = get(ag, c, nothing)
282282
if _alias !== nothing
283-
push!(dels, j)
284283
coeff, alias = _alias
285-
iszero(coeff) && (j += 1; continue)
286-
inc = coeff * rvals[j]
287-
i = searchsortedfirst(rs, alias)
288-
if i > length(rs) || rs[i] != alias
289-
# if we add a variable to what we already visited, make sure
290-
# to bump the cursor.
291-
j += i <= j
292-
for (i, e) in enumerate(dels)
293-
e >= i && (dels[i] += 1)
294-
end
295-
insert!(rs, i, alias)
296-
insert!(rvals, i, inc)
284+
if alias == c
285+
i = searchsortedfirst(rs, alias)
286+
rvals[i] *= coeff
297287
else
298-
rvals[i] += inc
288+
push!(dels, j)
289+
iszero(coeff) && (j += 1; continue)
290+
inc = coeff * rvals[j]
291+
i = searchsortedfirst(rs, alias)
292+
if i > length(rs) || rs[i] != alias
293+
# if we add a variable to what we already visited, make sure
294+
# to bump the cursor.
295+
j += i <= j
296+
for (i, e) in enumerate(dels)
297+
e >= i && (dels[i] += 1)
298+
end
299+
insert!(rs, i, alias)
300+
insert!(rvals, i, inc)
301+
else
302+
rvals[i] += inc
303+
end
299304
end
300305
end
301306
j += 1
@@ -651,6 +656,7 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig, irreducibles = ())
651656
ag[v] = 0
652657
end
653658

659+
echelon_mm = copy(mm)
654660
lss! = lss(mm, pivots, ag)
655661
# Step 2.1: Go backwards, collecting eliminated variables and substituting
656662
# alias as we go.
@@ -677,7 +683,7 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig, irreducibles = ())
677683
reduced && while any(lss!, 1:rank2)
678684
end
679685

680-
return mm
686+
return mm, echelon_mm
681687
end
682688

683689
function mark_processed!(processed, var_to_diff, v)
@@ -695,7 +701,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
695701
#
696702
nvars = ndsts(graph)
697703
ag = AliasGraph(nvars)
698-
mm = simple_aliases!(ag, graph, var_to_diff, mm_orig)
704+
mm, echelon_mm = simple_aliases!(ag, graph, var_to_diff, mm_orig)
699705
state = Main._state[]
700706
fullvars = state.fullvars
701707
for (v, (c, a)) in ag
@@ -718,14 +724,14 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
718724
#
719725
# where `-->` is an edge in `var_to_diff`, `⇒` is an edge in `ag`, and the
720726
# part in the box are purely conceptual, i.e. `D(D(D(z)))` doesn't appear in
721-
# the system.
727+
# the system. We call the variables in the box "virtual" variables.
722728
#
723729
# To finish the algorithm, we backtrack to the root differentiation chain.
724730
# If the variable already exists in the chain, then we alias them
725731
# (e.g. `x_t ⇒ D(D(z))`), else, we substitute and update `var_to_diff`.
726732
#
727733
# Note that since we always prefer the higher differentiated variable and
728-
# with a tie breaking strategy. The root variable (in this case `z`) is
734+
# with a tie breaking strategy, the root variable (in this case `z`) is
729735
# always uniquely determined. Thus, the result is well-defined.
730736
diff_to_var = invview(var_to_diff)
731737
invag = SimpleDiGraph(nvars)
@@ -735,10 +741,11 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
735741
end
736742
processed = falses(nvars)
737743
iag = InducedAliasGraph(ag, invag, var_to_diff)
738-
newag = AliasGraph(nvars)
744+
dag = AliasGraph(nvars) # alias graph for differentiated variables
739745
newinvag = SimpleDiGraph(nvars)
740-
irreducibles = BitSet()
746+
removed_aliases = BitSet()
741747
updated_diff_vars = Int[]
748+
irreducibles = Int[]
742749
for (v, dv) in enumerate(var_to_diff)
743750
processed[v] && continue
744751
(dv === nothing && diff_to_var[v] === nothing) && continue
@@ -753,19 +760,22 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
753760
nlevels = length(level_to_var)
754761
current_coeff_level = Ref((0, 0))
755762
add_alias! = let current_coeff_level = current_coeff_level,
756-
level_to_var = level_to_var, newag = newag, newinvag = newinvag,
763+
level_to_var = level_to_var, dag = dag, newinvag = newinvag,
757764
processed = processed
758765

759766
v -> begin
760767
coeff, level = current_coeff_level[]
761768
if level + 1 <= length(level_to_var)
762769
av = level_to_var[level + 1]
763770
if v != av # if the level_to_var isn't from the root branch
764-
newag[v] = coeff => av
771+
dag[v] = coeff => av
765772
add_edge!(newinvag, av, v)
766773
end
767774
else
768775
@assert length(level_to_var) == level
776+
if coeff != 1
777+
dag[v] = coeff => v
778+
end
769779
push!(level_to_var, v)
770780
end
771781
mark_processed!(processed, var_to_diff, v)
@@ -788,54 +798,38 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
788798
current_coeff_level[] = coeff, lv
789799
extreme_var(var_to_diff, v, nothing, Val(false), callback = add_alias!)
790800
end
791-
len = length(level_to_var)
792-
len > 1 || continue
793801

794-
set_v_zero! = let newag = newag
795-
v -> newag[v] = 0
802+
@show processed
803+
len = length(level_to_var)
804+
set_v_zero! = let dag = dag
805+
v -> dag[v] = 0
796806
end
807+
zero_av_idx = 0
797808
for (i, av) in enumerate(level_to_var)
798-
has_zero = false
809+
has_zero = iszero(get(ag, av, (1, 0))[1])
810+
push!(removed_aliases, av)
799811
for v in neighbors(newinvag, av)
800-
cv = get(ag, v, nothing)
801-
cv === nothing && continue
802-
c, v = cv
803-
iszero(c) || continue
804-
has_zero = true
805-
# if a chain starts to equal to zero, then all its descendants
806-
# must be zero and reducible
807-
if i < len
808-
# we have `x = 0`
809-
v = level_to_var[i + 1]
810-
extreme_var(var_to_diff, v, nothing, Val(false), callback = set_v_zero!)
811-
end
812-
break
812+
has_zero = has_zero || iszero(get(ag, v, (1, 0))[1])
813+
push!(removed_aliases, v)
813814
end
814-
has_zero && break
815-
816-
# all non-highest order differentiated variables are reducible.
817-
if i == len
818-
# if an irreducible alias appears in only one equation, then
819-
# it's actually not an alias, but a proper equation. E.g.
820-
# D(D(phi)) = a
821-
# D(phi) = sin(t)
822-
# `a` and `D(D(phi))` are not irreducible state. Hence, we need
823-
# to remove `av` from all alias graphs and mark those pairs
824-
# irreducible.
825-
push!(irreducibles, av)
826-
for v in neighbors(newinvag, av)
827-
newag[v] = nothing
828-
push!(irreducibles, v)
829-
end
830-
for v in neighbors(invag, av)
831-
newag[v] = nothing
832-
push!(irreducibles, v)
833-
end
834-
if (cv = get(ag, av, nothing)) !== nothing && !iszero(cv[2])
835-
push!(irreducibles, cv[2])
836-
end
815+
if zero_av_idx == 0 && has_zero
816+
zero_av_idx = i
817+
end
818+
end
819+
# If a chain starts to equal to zero, then all its derivatives must be
820+
# zero. Irreducible variables are highest differentiated variables (with
821+
# order >= 1) that are not zero.
822+
if zero_av_idx > 0
823+
extreme_var(var_to_diff, level_to_var[zero_av_idx], nothing, Val(false), callback = set_v_zero!)
824+
if zero_av_idx > 2
825+
@warn "1"
826+
push!(irreducibles, level_to_var[zero_av_idx - 1])
837827
end
828+
elseif len >= 2
829+
@warn "2"
830+
push!(irreducibles, level_to_var[len])
838831
end
832+
# Handle virtual variables
839833
if nlevels < len
840834
for i in (nlevels + 1):len
841835
li = level_to_var[i]
@@ -844,20 +838,42 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
844838
end
845839
end
846840
end
847-
for (v, (c, a)) in newag
848-
a = a == 0 ? 0 : c * fullvars[a]
849-
@info "differential aliases" fullvars[v] => a
850-
end
851-
@show fullvars[collect(irreducibles)]
852841

853-
if !isempty(irreducibles)
854-
ag = newag
855-
for k in keys(ag)
856-
push!(irreducibles, k)
842+
# Merge dag and ag
843+
freshag = AliasGraph(nvars)
844+
@show irreducibles
845+
@show dag
846+
for (v, (c, a)) in dag
847+
# TODO: make sure that `irreducibles` are
848+
# D(x) ~ D(y) cannot be removed if x and y are not aliases
849+
if v != a && a in irreducibles
850+
push!(removed_aliases, v)
851+
@goto NEXT_ITER
852+
elseif v != a && !iszero(a)
853+
vv = v
854+
aa = a
855+
while true
856+
vv′ = vv
857+
vv = diff_to_var[vv]
858+
vv === nothing && break
859+
if !(haskey(dag, vv) && dag[vv][2] == diff_to_var[aa])
860+
push!(removed_aliases, vv′)
861+
@goto NEXT_ITER
862+
end
863+
end
857864
end
858-
mm_orig2 = isempty(ag) ? mm_orig : reduce!(copy(mm_orig), ag)
859-
mm = simple_aliases!(ag, graph, var_to_diff, mm_orig2, irreducibles)
865+
freshag[v] = c => a
866+
@label NEXT_ITER
867+
end
868+
for (v, (c, a)) in ag
869+
v in removed_aliases && continue
870+
freshag[v] = c => a
871+
end
872+
if freshag != ag
873+
ag = freshag
874+
mm = reduce!(copy(echelon_mm), ag)
860875
end
876+
@info "" echelon_mm mm
861877

862878
for (v, (c, a)) in ag
863879
va = iszero(a) ? a : fullvars[a]
@@ -869,7 +885,6 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
869885
set_neighbors!(graph, e, mm.row_cols[ei])
870886
end
871887

872-
# because of `irreducibles`, `mm` cannot always be trusted.
873888
return ag, mm, updated_diff_vars
874889
end
875890

src/systems/diffeqs/odesystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,8 @@ function build_explicit_observed_function(sys, ts;
331331
eqs_cache[] = Dict(eq.lhs => eq.rhs for eq in equations(sys))
332332
end
333333
eqs_dict = eqs_cache[]
334-
rhs = get(eqs_dict, v, nothing)
334+
rhs_diffeq = get(eqs_dict, v, nothing)
335+
push!(obsexprs, v rhs_diffeq)
335336
if rhs === nothing
336337
error("The observed variable $(eq.lhs) depends on the differentiated variable $v, but it's not explicit solved. Fix file an issue if you are sure that the system is valid.")
337338
end

0 commit comments

Comments
 (0)