Skip to content

Commit 7f6548c

Browse files
committed
Add BitDict
1 parent 0444f7d commit 7f6548c

File tree

5 files changed

+75
-19
lines changed

5 files changed

+75
-19
lines changed

src/structural_transformation/partial_state_selection.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,8 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
240240
# We can eliminate variables that are not a selected state (differential
241241
# variables). Selected states are differentiated variables that are not
242242
# dummy derivatives.
243-
can_eliminate = let var_to_diff = var_to_diff, dummy_derivatives_set = dummy_derivatives_set
243+
can_eliminate = let var_to_diff = var_to_diff,
244+
dummy_derivatives_set = dummy_derivatives_set
244245

245246
v -> begin
246247
dv = var_to_diff[v]

src/structural_transformation/symbolics_tearing.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ function solve_equation(eq, var, simplify)
115115
var ~ rhs
116116
end
117117

118-
function substitute_vars!(graph::BipartiteGraph, subs, cache=Int[], callback! = nothing; exclude = ())
118+
function substitute_vars!(graph::BipartiteGraph, subs, cache = Int[], callback! = nothing;
119+
exclude = ())
119120
for su in subs
120121
su === nothing && continue
121122
v, v′ = su
@@ -361,7 +362,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
361362
add_edge!(solvable_graph, eq_idx, dx_idx)
362363
add_edge!(graph, eq_idx, x_t_idx)
363364
add_edge!(graph, eq_idx, dx_idx)
364-
365365
end
366366
# We use this info to substitute all `D(D(x))` or `D(x_t)` except
367367
# the `D(D(x)) ~ x_tt` equation to `x_tt`.
@@ -382,8 +382,10 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
382382
# substituted to `x_tt`.
383383
for idx in (ogidx, dx_idx)
384384
subidx = ((idx => x_t_idx),)
385-
substitute_vars!(graph, subidx, idx_buffer, sub_callback!; exclude = order_lowering_eqs)
386-
substitute_vars!(solvable_graph, subidx, idx_buffer; exclude = order_lowering_eqs)
385+
substitute_vars!(graph, subidx, idx_buffer, sub_callback!;
386+
exclude = order_lowering_eqs)
387+
substitute_vars!(solvable_graph, subidx, idx_buffer;
388+
exclude = order_lowering_eqs)
387389
end
388390
end
389391
empty!(subinfo)
@@ -453,7 +455,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
453455
push!(removed_eqs, ieq)
454456
push!(removed_vars, iv)
455457
else
456-
rhs = -b/a
458+
rhs = -b / a
457459
neweq = var ~ simplify ? Symbolics.simplify(rhs) : rhs
458460
push!(subeqs, neweq)
459461
push!(solved_equations, ieq)
@@ -497,7 +499,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
497499

498500
# Update system
499501
solved_variables_set = BitSet(solved_variables)
500-
active_vars = setdiff!(setdiff(BitSet(1:length(fullvars)), solved_variables_set), removed_vars)
502+
active_vars = setdiff!(setdiff(BitSet(1:length(fullvars)), solved_variables_set),
503+
removed_vars)
501504
new_var_to_diff = complete(DiffGraph(length(active_vars)))
502505
idx = 0
503506
for (v, d) in enumerate(var_to_diff)

src/systems/alias_elimination.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ function aag_bareiss(sys::AbstractSystem)
1919
return aag_bareiss!(state.structure.graph, complete(state.structure.var_to_diff), mm)
2020
end
2121

22-
function extreme_var(var_to_diff, v, level = nothing, ::Val{descend} = Val(true); callback = _ -> nothing) where descend
22+
function extreme_var(var_to_diff, v, level = nothing, ::Val{descend} = Val(true);
23+
callback = _ -> nothing) where {descend}
2324
g = descend ? invview(var_to_diff) : var_to_diff
2425
callback(v)
2526
while (v′ = g[v]) !== nothing
@@ -32,17 +33,20 @@ function extreme_var(var_to_diff, v, level = nothing, ::Val{descend} = Val(true)
3233
level === nothing ? v : (v => level)
3334
end
3435

35-
function walk_to_root!(iag, v::Integer, level = 0)
36+
function walk_to_root!(relative_level, iag, v::Integer, level = 0)
3637
brs = neighbors(iag, v)
3738
min_var_level = v => level
3839
for (x, lv′) in brs
3940
lv = lv′ + level
40-
x, lv = walk_to_root!(iag, x, lv)
41+
x, lv = walk_to_root!(relative_level, iag, x, lv)
42+
relative_level[x] = lv
4143
if min_var_level[2] > lv
4244
min_var_level = x => lv
4345
end
4446
end
45-
return extreme_var(iag.var_to_diff, min_var_level...)
47+
x, lv = extreme_var(iag.var_to_diff, min_var_level...)
48+
relative_level[x] = lv
49+
return x => lv
4650
end
4751

4852
function alias_elimination(sys)
@@ -87,22 +91,28 @@ function alias_elimination(sys)
8791
Main._a[] = ag, invag
8892
processed = falses(nvars)
8993
iag = InducedAliasGraph(ag, invag, var_to_diff, processed)
94+
relative_level = BitDict(nvars)
9095
newag = AliasGraph(nvars)
9196
for (v, dv) in enumerate(var_to_diff)
9297
processed[v] && continue
9398
(dv === nothing && diff_to_var[v] === nothing) && continue
9499

95100
# TODO: use an iterator, and get a relative level vector for `processed`
96101
# variabels.
97-
r, lv = walk_to_root!(iag, v)
102+
r, lv = walk_to_root!(relative_level, iag, v)
98103
fill!(processed, false)
99104
#lv = extreme_var(var_to_diff, v, -lv, Val(false))
100-
lv′ = extreme_var(var_to_diff, v, 0, Val(false))[2]
105+
lv′ = extreme_var(var_to_diff, v, 0)[2]
101106
let
102107
sv = fullvars[v]
103108
root = fullvars[r]
104-
@warn "" sv => root level = lv levelv = lv′
109+
@warn "" sv=>root level=lv levelv=lv′
110+
for (v, rl) in pairs(relative_level)
111+
@show v, rl
112+
@show fullvars[v], rl - lv, rl, lv
113+
end
105114
end
115+
empty!(relative_level)
106116
level_to_var = Int[r]
107117
v′′::Union{Nothing, Int} = v′::Int = r
108118
while (v′′ = var_to_diff[v′]) !== nothing
@@ -113,7 +123,7 @@ function alias_elimination(sys)
113123
if nlevels < (new_nlevels = length(level_to_var))
114124
@assert !(D isa Nothing)
115125
for i in (nlevels + 1):new_nlevels
116-
var_to_diff[level_to_var[i-1]] = level_to_var[i]
126+
var_to_diff[level_to_var[i - 1]] = level_to_var[i]
117127
fullvars[level_to_var[i]] = D(fullvars[level_to_var[i - 1]])
118128
end
119129
end
@@ -147,7 +157,7 @@ function alias_elimination(sys)
147157
end
148158

149159
newstates = []
150-
for j in eachindex(fullvars)
160+
for j in eachindex(fullvars)
151161
if j in keys(ag)
152162
_, var = ag[j]
153163
iszero(var) && continue
@@ -316,7 +326,9 @@ struct InducedAliasGraph
316326
visited::BitVector
317327
end
318328

319-
InducedAliasGraph(ag, invag, var_to_diff) = InducedAliasGraph(ag, invag, var_to_diff, falses(nv(invag)))
329+
function InducedAliasGraph(ag, invag, var_to_diff)
330+
InducedAliasGraph(ag, invag, var_to_diff, falses(nv(invag)))
331+
end
320332

321333
struct IAGNeighbors
322334
iag::InducedAliasGraph
@@ -540,7 +552,8 @@ function locally_structure_simplify!(adj_row, pivot_var, ag, var_to_diff)
540552
# loop.
541553
#
542554
# We're relying on `var` being produced in sorted order here.
543-
nirreducible += !(alias_candidate isa Pair) || alias_var != alias_candidate[2]
555+
nirreducible += !(alias_candidate isa Pair) ||
556+
alias_var != alias_candidate[2]
544557
alias_candidate = new_coeff => alias_var
545558
end
546559
end

src/systems/systemstructure.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ struct DiffChainIterator{Descend}
147147
v::Int
148148
end
149149

150-
function Base.iterate(di::DiffChainIterator{Descend}, v=nothing) where Descend
150+
function Base.iterate(di::DiffChainIterator{Descend}, v = nothing) where {Descend}
151151
if v === nothing
152152
vv = di.v
153153
return (vv, vv)

src/utils.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,3 +590,42 @@ function promote_to_concrete(vs; tofloat = true, use_union = false)
590590
convert.(C, vs)
591591
end
592592
end
593+
594+
struct BitDict <: AbstractDict{Int, Int}
595+
keys::Vector{Int}
596+
values::Vector{Union{Nothing, Int}}
597+
end
598+
BitDict(n::Integer) = BitDict(Int[], Union{Nothing, Int}[nothing for _ in 1:n])
599+
struct BitDictKeySet <: AbstractSet{Int}
600+
d::BitDict
601+
end
602+
603+
Base.keys(d::BitDict) = BitDictKeySet(d)
604+
Base.in(v::Integer, s::BitDictKeySet) = s.d.values[v] !== nothing
605+
Base.iterate(s::BitDictKeySet, state...) = iterate(s.d.keys, state...)
606+
function Base.setindex!(d::BitDict, val::Integer, ind::Integer)
607+
if 1 <= ind <= length(d.values) && d.values[ind] === nothing
608+
push!(d.keys, ind)
609+
end
610+
d.values[ind] = val
611+
end
612+
function Base.getindex(d::BitDict, ind::Integer)
613+
if 1 <= ind <= length(d.values) && d.values[ind] === nothing
614+
return d.values[ind]
615+
else
616+
throw(KeyError(ind))
617+
end
618+
end
619+
function Base.iterate(d::BitDict, state...)
620+
r = Base.iterate(d.keys, state...)
621+
r === nothing && return nothing
622+
k, state = r
623+
(k => d.values[k]), state
624+
end
625+
function Base.empty!(d::BitDict)
626+
for v in d.keys
627+
d.values[v] = nothing
628+
end
629+
empty!(d.keys)
630+
d
631+
end

0 commit comments

Comments
 (0)