Skip to content

Commit 8f7e24b

Browse files
committed
Hey, this kinda works
1 parent 9822ee7 commit 8f7e24b

File tree

1 file changed

+42
-13
lines changed

1 file changed

+42
-13
lines changed

src/structural_transformation/codegen.jl

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_s
9595
sparse(I, J, true)
9696
end
9797

98-
function gen_nlsolve!(is_not_prepended_assignment, eqs, vars, u0map::AbstractDict, assignments, deps, var2assignment; checkbounds=true)
98+
function gen_nlsolve!(is_not_prepended_assignment, eqs, vars, u0map::AbstractDict, assignments, (deps, invdeps), var2assignment; checkbounds=true)
9999
isempty(vars) && throw(ArgumentError("vars may not be empty"))
100100
length(eqs) == length(vars) || throw(ArgumentError("vars must be of the same length as the number of equations to find the roots of"))
101101
rhss = map(x->x.rhs, eqs)
@@ -106,7 +106,6 @@ function gen_nlsolve!(is_not_prepended_assignment, eqs, vars, u0map::AbstractDic
106106
init_assignments = [var2assignment[p] for p in paramset if haskey(var2assignment, p)]
107107
tmp = [init_assignments]
108108
# `deps[init_assignments]` gives the dependency of `init_assignments`
109-
successors = Dict{Int,Vector{Int}}()
110109
while true
111110
next_assignments = reduce(vcat, deps[init_assignments])
112111
isempty(next_assignments) && break
@@ -118,19 +117,43 @@ function gen_nlsolve!(is_not_prepended_assignment, eqs, vars, u0map::AbstractDic
118117

119118
# Compute `params`. They are like enclosed variables
120119
rhsvars = [ModelingToolkit.vars(r.rhs) for r in needed_assignments]
121-
is_vars_independent = isdisjoint.((vars,), rhsvars)
122-
inner_assignments = []; outer_idxs = Int[]
123-
outer_assignments = []; inner_idxs = Int[]
124-
for (i, ind) in enumerate(is_vars_independent)
125-
a = needed_assignments[i]
126-
if ind
127-
push!(outer_assignments, a)
128-
push!(outer_idxs, i)
120+
vars_set = Set(vars)
121+
outer_set = BitSet()
122+
inner_set = BitSet()
123+
for (i, vs) in enumerate(rhsvars)
124+
j = needed_assignments_idxs[i]
125+
if isdisjoint(vars_set, vs)
126+
push!(outer_set, j)
129127
else
130-
push!(inner_assignments, a)
131-
push!(inner_idxs, i)
128+
push!(inner_set, j)
132129
end
133130
end
131+
init_refine = BitSet()
132+
for i in inner_set
133+
union!(init_refine, invdeps[i])
134+
end
135+
intersect!(init_refine, outer_set)
136+
setdiff!(outer_set, init_refine)
137+
union!(inner_set, init_refine)
138+
139+
next_refine = BitSet()
140+
while true
141+
for i in init_refine
142+
id = invdeps[i]
143+
isempty(id) && break
144+
union!(next_refine, id)
145+
end
146+
intersect!(next_refine, outer_set)
147+
isempty(next_refine) && break
148+
setdiff!(outer_set, next_refine)
149+
union!(inner_set, next_refine)
150+
151+
init_refine, next_refine = next_refine, init_refine
152+
empty!(next_refine)
153+
end
154+
global2local = Dict(j=>i for (i, j) in enumerate(needed_assignments_idxs))
155+
inner_idxs = [global2local[i] for i in collect(inner_set)]
156+
outer_idxs = [global2local[i] for i in collect(outer_set)]
134157
extravars = reduce(union!, rhsvars[inner_idxs], init=Set())
135158
union!(paramset, extravars)
136159
setdiff!(paramset, vars)
@@ -227,6 +250,12 @@ function build_torn_function(
227250
mass_matrix_diag = ones(length(states_idxs))
228251

229252
assignments, deps, bf_states = tearing_assignments(sys)
253+
invdeps = map(_->BitSet(), deps)
254+
for (i, d) in enumerate(deps)
255+
for a in d
256+
push!(invdeps[a], i)
257+
end
258+
end
230259
var2assignment = Dict{Any,Int}(eq.lhs => i for (i, eq) in enumerate(assignments))
231260
is_not_prepended_assignment = trues(length(assignments))
232261

@@ -241,7 +270,7 @@ function build_torn_function(
241270
torn_eqs_idxs = [var_eq_matching[var] for var in torn_vars_idxs]
242271
isempty(torn_eqs_idxs) && continue
243272
if length(torn_eqs_idxs) <= max_inlining_size
244-
nlsolve_expr = gen_nlsolve!(is_not_prepended_assignment, eqs[torn_eqs_idxs], s.fullvars[torn_vars_idxs], defs, assignments, deps, var2assignment, checkbounds=checkbounds)
273+
nlsolve_expr = gen_nlsolve!(is_not_prepended_assignment, eqs[torn_eqs_idxs], s.fullvars[torn_vars_idxs], defs, assignments, (deps, invdeps), var2assignment, checkbounds=checkbounds)
245274
append!(torn_expr, nlsolve_expr)
246275
push!(nlsolve_scc_idxs, i)
247276
else

0 commit comments

Comments
 (0)