Skip to content

Commit e15c650

Browse files
authored
Merge pull request #1888 from SciML/myb/faster_closure
Use faster data structure for the transitive closure calculation and fast symbolic substitution
2 parents 3ee1f87 + f934c96 commit e15c650

File tree

6 files changed

+103
-36
lines changed

6 files changed

+103
-36
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
3636
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
3737
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
3838
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
39-
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
4039
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4140
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4241
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -73,7 +72,6 @@ Reexport = "0.2, 1"
7372
RuntimeGeneratedFunctions = "0.4.3, 0.5"
7473
SciMLBase = "1.58.0"
7574
Setfield = "0.7, 0.8, 1"
76-
SimpleWeightedGraphs = "1"
7775
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
7876
StaticArrays = "0.10, 0.11, 0.12, 1.0"
7977
SymbolicUtils = "0.19"

src/inputoutput.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,17 +119,17 @@ function same_or_inner_namespace(u, var)
119119
nv = get_namespace(var)
120120
nu == nv || # namespaces are the same
121121
startswith(nv, nu) || # or nv starts with nu, i.e., nv is an inner namepsace to nu
122-
occursin('', string(Symbolics.getname(var))) &&
123-
!occursin('', string(Symbolics.getname(u))) # or u is top level but var is internal
122+
occursin('', string(getname(var))) &&
123+
!occursin('', string(getname(u))) # or u is top level but var is internal
124124
end
125125

126126
function inner_namespace(u, var)
127127
nu = get_namespace(u)
128128
nv = get_namespace(var)
129129
nu == nv && return false
130130
startswith(nv, nu) || # or nv starts with nu, i.e., nv is an inner namepsace to nu
131-
occursin('', string(Symbolics.getname(var))) &&
132-
!occursin('', string(Symbolics.getname(u))) # or u is top level but var is internal
131+
occursin('', string(getname(var))) &&
132+
!occursin('', string(getname(u))) # or u is top level but var is internal
133133
end
134134

135135
"""
@@ -138,7 +138,7 @@ end
138138
Return the namespace of a variable as a string. If the variable is not namespaced, the string is empty.
139139
"""
140140
function get_namespace(x)
141-
sname = string(Symbolics.getname(x))
141+
sname = string(getname(x))
142142
parts = split(sname, '')
143143
if length(parts) == 1
144144
return ""

src/structural_transformation/StructuralTransformations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
2222
get_postprocess_fbody, vars!,
2323
IncrementalCycleTracker, add_edge_checked!, topological_sort,
2424
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
25-
AliasGraph, filter_kwargs, lower_varname, setio, SparseMatrixCLIL
25+
AliasGraph, filter_kwargs, lower_varname, setio, SparseMatrixCLIL,
26+
fast_substitute
2627

2728
using ModelingToolkit.BipartiteGraphs
2829
import .BipartiteGraphs: invview, complete

src/structural_transformation/symbolics_tearing.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
227227
idx_buffer = Int[]
228228
sub_callback! = let eqs = neweqs, fullvars = fullvars
229229
(ieq, s) -> begin
230-
neweq = substitute(eqs[ieq], fullvars[s[1]] => fullvars[s[2]])
230+
neweq = fast_substitute(eqs[ieq], fullvars[s[1]] => fullvars[s[2]])
231231
eqs[ieq] = neweq
232232
end
233233
end
@@ -282,7 +282,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
282282
end
283283
for eq in 𝑑neighbors(graph, dv)
284284
dummy_sub[dd] = v_t
285-
neweqs[eq] = substitute(neweqs[eq], dd => v_t)
285+
neweqs[eq] = fast_substitute(neweqs[eq], dd => v_t)
286286
end
287287
fullvars[dv] = v_t
288288
# If we have:
@@ -295,7 +295,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
295295
while (ddx = var_to_diff[dx]) !== nothing
296296
dx_t = D(x_t)
297297
for eq in 𝑑neighbors(graph, ddx)
298-
neweqs[eq] = substitute(neweqs[eq], fullvars[ddx] => dx_t)
298+
neweqs[eq] = fast_substitute(neweqs[eq], fullvars[ddx] => dx_t)
299299
end
300300
fullvars[ddx] = dx_t
301301
dx = ddx
@@ -655,7 +655,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
655655
obs_sub[eq.lhs] = eq.rhs
656656
end
657657
# TODO: compute the dependency correctly so that we don't have to do this
658-
obs = substitute.([oldobs; subeqs], (obs_sub,))
658+
obs = fast_substitute([oldobs; subeqs], obs_sub)
659659
@set! sys.observed = obs
660660
@set! state.sys = sys
661661
@set! sys.tearing_state = state

src/systems/alias_elimination.jl

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using SymbolicUtils: Rewriters
2-
using SimpleWeightedGraphs
32
using Graphs.Experimental.Traversals
43

54
const KEEP = typemin(Int)
@@ -153,7 +152,8 @@ function alias_elimination!(state::TearingState; kwargs...)
153152
end
154153
end
155154
for ieq in eqs_to_update
156-
eqs[ieq] = substitute(eqs[ieq], subs)
155+
eq = eqs[ieq]
156+
eqs[ieq] = fast_substitute(eq, subs)
157157
end
158158

159159
for old_ieq in to_expand
@@ -365,9 +365,33 @@ function Base.in(i::Int, agk::AliasGraphKeySet)
365365
1 <= i <= length(aliasto) && aliasto[i] !== nothing
366366
end
367367

368+
canonicalize(a, b) = a <= b ? (a, b) : (b, a)
369+
struct WeightedGraph{T, W} <: AbstractGraph{T}
370+
graph::SimpleGraph{T}
371+
dict::Dict{Tuple{T, T}, W}
372+
end
373+
function WeightedGraph{T, W}(n) where {T, W}
374+
WeightedGraph{T, W}(SimpleGraph{T}(n), Dict{Tuple{T, T}, W}())
375+
end
376+
377+
function Graphs.add_edge!(g::WeightedGraph, u, v, w)
378+
r = add_edge!(g.graph, u, v)
379+
r && (g.dict[canonicalize(u, v)] = w)
380+
r
381+
end
382+
Graphs.has_edge(g::WeightedGraph, u, v) = has_edge(g.graph, u, v)
383+
Graphs.ne(g::WeightedGraph) = ne(g.graph)
384+
Graphs.nv(g::WeightedGraph) = nv(g.graph)
385+
get_weight(g::WeightedGraph, u, v) = g.dict[canonicalize(u, v)]
386+
Graphs.is_directed(::Type{<:WeightedGraph}) = false
387+
Graphs.inneighbors(g::WeightedGraph, v) = inneighbors(g.graph, v)
388+
Graphs.outneighbors(g::WeightedGraph, v) = outneighbors(g.graph, v)
389+
Graphs.vertices(g::WeightedGraph) = vertices(g.graph)
390+
Graphs.edges(g::WeightedGraph) = vertices(g.graph)
391+
368392
function equality_diff_graph(ag::AliasGraph, var_to_diff::DiffGraph)
369393
g = SimpleDiGraph{Int}(length(var_to_diff))
370-
eqg = SimpleWeightedGraph{Int, Int}(length(var_to_diff))
394+
eqg = WeightedGraph{Int, Int}(length(var_to_diff))
371395
zero_vars = Int[]
372396
for (v, (c, a)) in ag
373397
if iszero(a)
@@ -378,7 +402,6 @@ function equality_diff_graph(ag::AliasGraph, var_to_diff::DiffGraph)
378402
add_edge!(g, a, v)
379403

380404
add_edge!(eqg, v, a, c)
381-
add_edge!(eqg, a, v, c)
382405
end
383406
transitiveclosure!(g)
384407
weighted_transitiveclosure!(eqg)
@@ -394,9 +417,14 @@ end
394417
function weighted_transitiveclosure!(g)
395418
cps = connected_components(g)
396419
for cp in cps
397-
for k in cp, i in cp, j in cp
398-
(has_edge(g, i, k) && has_edge(g, k, j)) || continue
399-
add_edge!(g, i, j, get_weight(g, i, k) * get_weight(g, k, j))
420+
n = length(cp)
421+
for k in cp
422+
for i′ in 1:n, j′ in (i′ + 1):n
423+
i = cp[i′]
424+
j = cp[j′]
425+
(has_edge(g, i, k) && has_edge(g, k, j)) || continue
426+
add_edge!(g, i, j, get_weight(g, i, k) * get_weight(g, k, j))
427+
end
400428
end
401429
end
402430
return g
@@ -670,11 +698,12 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
670698
(v, w) -> var_to_diff[v] == w || var_to_diff[w] == v
671699
end
672700
diff_aliases = Vector{Pair{Int, Int}}[]
673-
stem = Int[]
701+
stems = Vector{Int}[]
674702
stem_set = BitSet()
675703
for (v, dv) in enumerate(var_to_diff)
676704
processed[v] && continue
677705
(dv === nothing && diff_to_var[v] === nothing) && continue
706+
stem = Int[]
678707
r = find_root!(dls, g, v)
679708
prev_r = -1
680709
for _ in 1:10_000 # just to make sure that we don't stuck in an infinite loop
@@ -714,9 +743,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
714743
push!(stem_set, prev_r)
715744
push!(stem, prev_r)
716745
push!(diff_aliases, reach₌)
717-
for (_, v) in reach₌
746+
for (c, v) in reach₌
718747
v == prev_r && continue
719-
add_edge!(eqg, v, prev_r)
748+
add_edge!(eqg, v, prev_r, c)
720749
end
721750
end
722751

@@ -729,9 +758,24 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
729758
dag[v] = c => a
730759
end
731760
end
732-
# Obtain transitive closure after completing the alias edges from diff
733-
# edges.
734-
weighted_transitiveclosure!(eqg)
761+
push!(stems, stem)
762+
763+
# clean up
764+
for v in dls.visited
765+
dls.dists[v] = typemax(Int)
766+
processed[v] = true
767+
end
768+
empty!(dls.visited)
769+
empty!(diff_aliases)
770+
empty!(stem_set)
771+
end
772+
773+
# Obtain transitive closure after completing the alias edges from diff
774+
# edges. As a performance optimization, we only compute the transitive
775+
# closure once at the very end.
776+
weighted_transitiveclosure!(eqg)
777+
zero_vars_set = BitSet()
778+
for stem in stems
735779
# Canonicalize by preferring the lower differentiated variable
736780
# If we have the system
737781
# ```
@@ -780,7 +824,6 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
780824
# x := 0
781825
# y := 0
782826
# ```
783-
zero_vars_set = BitSet()
784827
for v in zero_vars
785828
for a in Iterators.flatten((v, outneighbors(eqg, v)))
786829
while true
@@ -803,17 +846,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
803846
dag[v] = 0
804847
end
805848
end
806-
807-
# clean up
808-
for v in dls.visited
809-
dls.dists[v] = typemax(Int)
810-
processed[v] = true
811-
end
812-
empty!(dls.visited)
813-
empty!(diff_aliases)
814-
empty!(stem)
815-
empty!(stem_set)
849+
empty!(zero_vars_set)
816850
end
851+
817852
# update `dag`
818853
for k in keys(dag)
819854
dag[k]

src/utils.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,3 +741,36 @@ function jacobian_wrt_vars(pf::F, p, input_idxs, chunk::C) where {F, C}
741741
cfg = ForwardDiff.JacobianConfig(p_closure, p_small, chunk, tag)
742742
ForwardDiff.jacobian(p_closure, p_small, cfg, Val(false))
743743
end
744+
745+
# Symbolics needs to call unwrap on the substitution rules, but most of the time
746+
# we don't want to do that in MTK.
747+
function fast_substitute(eq::Equation, subs)
748+
fast_substitute(eq.lhs, subs) ~ fast_substitute(eq.rhs, subs)
749+
end
750+
function fast_substitute(eq::Equation, subs::Pair)
751+
fast_substitute(eq.lhs, subs) ~ fast_substitute(eq.rhs, subs)
752+
end
753+
fast_substitute(eqs::AbstractArray{Equation}, subs) = fast_substitute.(eqs, (subs,))
754+
fast_substitute(a, b) = substitute(a, b)
755+
function fast_substitute(expr, pair::Pair)
756+
a, b = pair
757+
isequal(expr, a) && return b
758+
759+
istree(expr) || return expr
760+
op = fast_substitute(operation(expr), pair)
761+
canfold = Ref(!(op isa Symbolic))
762+
args = let canfold = canfold
763+
map(SymbolicUtils.unsorted_arguments(expr)) do x
764+
x′ = fast_substitute(x, pair)
765+
canfold[] = canfold[] && !(x′ isa Symbolic)
766+
x′
767+
end
768+
end
769+
canfold[] && return op(args...)
770+
771+
similarterm(expr,
772+
op,
773+
args,
774+
symtype(expr);
775+
metadata = metadata(expr))
776+
end

0 commit comments

Comments
 (0)