Skip to content

Commit 9822ee7

Browse files
committed
WIP
1 parent 7015009 commit 9822ee7

File tree

1 file changed

+42
-21
lines changed

1 file changed

+42
-21
lines changed

src/structural_transformation/codegen.jl

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -95,25 +95,47 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_s
9595
sparse(I, J, true)
9696
end
9797

98-
function gen_nlsolve(eqs, vars, u0map::AbstractDict, assignments, deps, var2assignment; checkbounds=true)
98+
function gen_nlsolve!(is_not_prepended_assignment, eqs, vars, u0map::AbstractDict, assignments, deps, var2assignment; checkbounds=true)
9999
isempty(vars) && throw(ArgumentError("vars may not be empty"))
100100
length(eqs) == length(vars) || throw(ArgumentError("vars must be of the same length as the number of equations to find the roots of"))
101101
rhss = map(x->x.rhs, eqs)
102102
# We use `vars` instead of `graph` to capture parameters, too.
103-
paramset = Set{Any}(Iterators.flatten(ModelingToolkit.vars(r) for r in rhss))
103+
paramset = ModelingToolkit.vars(r for r in rhss)
104104

105+
# Compute necessary assignments for the nlsolve expr
105106
init_assignments = [var2assignment[p] for p in paramset if haskey(var2assignment, p)]
106107
tmp = [init_assignments]
107108
# `deps[init_assignments]` gives the dependency of `init_assignments`
108-
while (next_assignments = reduce(vcat, deps[init_assignments]); !isempty(next_assignments))
109+
successors = Dict{Int,Vector{Int}}()
110+
while true
111+
next_assignments = reduce(vcat, deps[init_assignments])
112+
isempty(next_assignments) && break
109113
init_assignments = next_assignments
110114
push!(tmp, init_assignments)
111115
end
112-
needed_assignments = mapreduce(i->assignments[i], vcat, unique(reverse(tmp)))
113-
extravars = Set{Any}(Iterators.flatten(ModelingToolkit.vars(r.rhs) for r in needed_assignments))
116+
needed_assignments_idxs = reduce(vcat, unique(reverse(tmp)))
117+
needed_assignments = assignments[needed_assignments_idxs]
118+
119+
# Compute `params`. They are like enclosed variables
120+
rhsvars = [ModelingToolkit.vars(r.rhs) for r in needed_assignments]
121+
is_vars_independent = isdisjoint.((vars,), rhsvars)
122+
inner_assignments = []; outer_idxs = Int[]
123+
outer_assignments = []; inner_idxs = Int[]
124+
for (i, ind) in enumerate(is_vars_independent)
125+
a = needed_assignments[i]
126+
if ind
127+
push!(outer_assignments, a)
128+
push!(outer_idxs, i)
129+
else
130+
push!(inner_assignments, a)
131+
push!(inner_idxs, i)
132+
end
133+
end
134+
extravars = reduce(union!, rhsvars[inner_idxs], init=Set())
114135
union!(paramset, extravars)
115-
# these are not the subject of the root finding
116-
setdiff!(paramset, vars); setdiff!(paramset, map(a->a.lhs, needed_assignments))
136+
setdiff!(paramset, vars)
137+
setdiff!(paramset, [needed_assignments[i].lhs for i in inner_idxs])
138+
union!(paramset, [needed_assignments[i].lhs for i in outer_idxs])
117139
params = collect(paramset)
118140

119141
# splatting to tighten the type
@@ -141,7 +163,7 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict, assignments, deps, var2assi
141163
],
142164
[],
143165
Let(
144-
needed_assignments,
166+
needed_assignments[inner_idxs],
145167
isscalar ? rhss[1] : MakeArray(rhss, SVector)
146168
)
147169
) |> SymbolicUtils.Code.toexpr
@@ -157,7 +179,16 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict, assignments, deps, var2assi
157179
)
158180
end)
159181

182+
preassignments = []
183+
for i in outer_idxs
184+
ii = needed_assignments_idxs[i]
185+
is_not_prepended_assignment[ii] || continue
186+
is_not_prepended_assignment[ii] = false
187+
push!(preassignments, assignments[ii])
188+
end
189+
160190
nlsolve_expr = Assignment[
191+
preassignments
161192
fname @RuntimeGeneratedFunction(f)
162193
DestructuredArgs(vars, inbounds=!checkbounds) solver_call
163194
]
@@ -197,6 +228,7 @@ function build_torn_function(
197228

198229
assignments, deps, bf_states = tearing_assignments(sys)
199230
var2assignment = Dict{Any,Int}(eq.lhs => i for (i, eq) in enumerate(assignments))
231+
is_not_prepended_assignment = trues(length(assignments))
200232

201233
torn_expr = Assignment[]
202234

@@ -209,18 +241,7 @@ function build_torn_function(
209241
torn_eqs_idxs = [var_eq_matching[var] for var in torn_vars_idxs]
210242
isempty(torn_eqs_idxs) && continue
211243
if length(torn_eqs_idxs) <= max_inlining_size
212-
nlsolve_expr = gen_nlsolve(eqs[torn_eqs_idxs], s.fullvars[torn_vars_idxs], defs, assignments, deps, var2assignment, checkbounds=checkbounds)
213-
#=
214-
# a temporary vector that we need to reverse to get the correct
215-
# dependency evaluation order.
216-
local_deps = Vector{Int}[]
217-
init_deps = [var2assignment[p] for p in params if haskey(var2assignment, p)]
218-
push!(local_deps, init_deps)
219-
while (next_deps = reduce(vcat, deps[init_deps]); !isempty(next_deps))
220-
init_deps = next_deps
221-
push!(local_deps, init_deps)
222-
end
223-
=#
244+
nlsolve_expr = gen_nlsolve!(is_not_prepended_assignment, eqs[torn_eqs_idxs], s.fullvars[torn_vars_idxs], defs, assignments, deps, var2assignment, checkbounds=checkbounds)
224245
append!(torn_expr, nlsolve_expr)
225246
push!(nlsolve_scc_idxs, i)
226247
else
@@ -256,7 +277,7 @@ function build_torn_function(
256277
],
257278
[],
258279
pre(Let(
259-
[torn_expr; assignments],
280+
[torn_expr; assignments[is_not_prepended_assignment]],
260281
funbody
261282
))
262283
),

0 commit comments

Comments
 (0)