Skip to content

Commit d4e53db

Browse files
refactor: port SCCNonlinearProblem to separate file
1 parent 7a5c7a8 commit d4e53db

File tree

1 file changed

+246
-0
lines changed

1 file changed

+246
-0
lines changed
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
# Alias for the type objects (concrete `DataType`s or `UnionAll`s) that this file uses to
# identify cache buffers, e.g. as `Dict` keys and as elements of the buffer-type list.
const TypeT = Union{DataType, UnionAll}
2+
3+
"""
    CacheWriter{F}

Callable wrapper around a function `fn`. Calling a `CacheWriter` with `(p, sols)`
forwards to `fn(p.caches, sols, p)` (see the call overload defined in this file).

# Fields
- `fn::F`: the wrapped function.
"""
struct CacheWriter{F}
    fn::F
end
6+
7+
"""
    (cw::CacheWriter)(p, sols)

Call the wrapped function as `cw.fn(p.caches, sols, p)` and return its result.

# Arguments
- `p`: an object with a `caches` property; it is also passed through as the last argument.
- `sols`: forwarded unchanged as the second argument of `cw.fn`.
"""
function (cw::CacheWriter)(p, sols)
    return cw.fn(p.caches, sols, p)
end
10+
11+
"""
    CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
        exprs::Dict{TypeT, Vector{Any}}, solsyms, obseqs::Vector{Equation};
        eval_expression = false, eval_module = @__MODULE__, cse = true)

Generate a function that writes cached expressions into buffers and wrap it in a
`CacheWriter`.

# Arguments
- `sys`: the system whose (reordered) parameters become trailing arguments of the
  generated function.
- `buffer_types`: ordered list of buffer types; the `i`-th one is written to `out[i]`.
- `exprs`: map from buffer type to the expressions stored in that buffer. A type with no
  entry yields an empty list of expressions.
- `solsyms`: collection of symbol collections; together they are destructured from the
  first generated argument.
- `obseqs`: observed equations emitted as assignments before the buffer writes.

# Keywords
- `eval_expression`, `eval_module`: forwarded to `eval_or_rgf`.
- `cse`: forwarded to `build_function_wrapper`.

# Returns
A `CacheWriter` wrapping a `GeneratedFunctionWrapper{(3, 3, is_split(sys))}`.
"""
function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
        exprs::Dict{TypeT, Vector{Any}}, solsyms, obseqs::Vector{Equation};
        eval_expression = false, eval_module = @__MODULE__, cse = true)
    ps = parameters(sys; initial_parameters = true)
    rps = reorder_parameters(sys, ps)

    # NOTE(review): the `←` assignment operator between the two operands was missing
    # in this copy of the file (leaving juxtaposed expressions that do not parse). It
    # is restored here and below - confirm `←` is in scope alongside `SetArray`.
    obs_assigns = [eq.lhs ← eq.rhs for eq in obseqs]
    # One `SetArray` per buffer type, writing that type's expressions into `out[i]`.
    body = map(eachindex(buffer_types), buffer_types) do i, T
        Symbol(:tmp, i) ← SetArray(true, :(out[$i]), get(exprs, T, []))
    end

    # Indices up to `length(solsyms)` refer to entries of the first generated argument;
    # larger indices are shifted down to name the remaining generated arguments.
    function argument_name(i::Int)
        if i <= length(solsyms)
            return :($(generated_argument_name(1))[$i])
        end
        return generated_argument_name(i - length(solsyms))
    end
    array_assignments = array_variable_assignments(solsyms...; argument_name)

    fn_expr = build_function_wrapper(
        sys, nothing, :out,
        DestructuredArgs(DestructuredArgs.(solsyms), generated_argument_name(1)),
        rps...; p_start = 3, p_end = length(rps) + 2,
        expression = Val{true}, add_observed = false, cse,
        extra_assignments = [array_assignments; obs_assigns; body])
    fn_compiled = eval_or_rgf(fn_expr; eval_expression, eval_module)
    fn_wrapped = GeneratedFunctionWrapper{(3, 3, is_split(sys))}(fn_compiled, nothing)
    return CacheWriter(fn_wrapped)
end
38+
39+
# Empty type parameterized by `iip`. The `SCCNonlinearFunction{iip}(...)` method defined
# in this file returns a `NonlinearFunction{iip}` rather than an instance of this struct.
struct SCCNonlinearFunction{iip} end
40+
41+
"""
    SCCNonlinearFunction{iip}(sys::System, _eqs, _dvs, _obs, cachesyms;
        eval_expression = false, eval_module = @__MODULE__, cse = true, kwargs...)

Build a `NonlinearFunction{iip}` for the subset of equations `_eqs` in the unknowns
`_dvs`.

# Arguments
- `sys`: the parent system; supplies parameters, parameter dependencies, name and
  (when present) the index cache.
- `_eqs`: equations of the subproblem; the residual of each is `rhs - lhs`.
- `_dvs`: unknowns of the subproblem.
- `_obs`: observed equations emitted as extra assignments in the generated code and
  attached to the sub-system as `observed`.
- `cachesyms`: collection of cache-buffer symbol lists, splatted as additional
  arguments after the reordered parameters.

# Keywords
- `eval_expression`, `eval_module`: forwarded to `eval_or_rgf`.
- `cse`: forwarded to `build_function_wrapper`.
- `kwargs...`: accepted but not used by this method.

# Returns
`NonlinearFunction{iip}(f; sys = subsys)`, where `subsys` is a `NonlinearSystem` built
from `_eqs`, `_dvs` and the parameters of `sys`.

NOTE(review): `sys` is annotated as `System` here, while the `SCCNonlinearProblem`
methods in this file take `sys::NonlinearSystem` and call this constructor - confirm
the two annotations are compatible.
"""
function SCCNonlinearFunction{iip}(
        sys::System, _eqs, _dvs, _obs, cachesyms; eval_expression = false,
        eval_module = @__MODULE__, cse = true, kwargs...) where {iip}
    ps = parameters(sys; initial_parameters = true)
    rps = reorder_parameters(sys, ps)

    # NOTE(review): the `←` assignment operator was missing between `eq.lhs` and
    # `eq.rhs` in this copy of the file (which does not parse); restored here -
    # confirm `←` is in scope.
    obs_assignments = [eq.lhs ← eq.rhs for eq in _obs]

    rhss = [eq.rhs - eq.lhs for eq in _eqs]
    f_gen = build_function_wrapper(sys,
        rhss, _dvs, rps..., cachesyms...; p_start = 2,
        p_end = length(rps) + length(cachesyms) + 1, add_observed = false,
        extra_assignments = obs_assignments, expression = Val{true}, cse)
    f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
    f = GeneratedFunctionWrapper{(2, 2, is_split(sys))}(f_oop, f_iip)

    subsys = NonlinearSystem(_eqs, _dvs, ps; observed = _obs,
        parameter_dependencies = parameter_dependencies(sys), name = nameof(sys))
    # Only when the parent carries an index cache: give the sub-system a matching
    # subset and mark it complete.
    if get_index_cache(sys) !== nothing
        @set! subsys.index_cache = subset_unknowns_observed(
            get_index_cache(sys), sys, _dvs, getproperty.(_obs, (:lhs,)))
        @set! subsys.complete = true
    end

    return NonlinearFunction{iip}(f; sys = subsys)
end
67+
68+
"""
    SciMLBase.SCCNonlinearProblem(sys::NonlinearSystem, args...; kwargs...)

Convenience method that forwards to `SCCNonlinearProblem{true}(sys, args...; kwargs...)`,
i.e. with the `iip` type parameter set to `true`.
"""
function SciMLBase.SCCNonlinearProblem(sys::NonlinearSystem, args...; kwargs...)
    return SCCNonlinearProblem{true}(sys, args...; kwargs...)
end
71+
72+
"""
    SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
        parammap = SciMLBase.NullParameters(); eval_expression = false,
        eval_module = @__MODULE__, cse = true, kwargs...)

Build an `SCCNonlinearProblem` from `sys` by splitting it into strongly connected
components (SCCs) of its unknowns and creating one `NonlinearProblem` per SCC.

Outline:
1. Require a complete system with a tearing state and `split = true`; otherwise `error`.
2. Compute the SCCs. If there is exactly one, return a plain `NonlinearProblem{iip}`.
3. Topologically sort the SCCs and, in a first pass, collect per-SCC equations,
   observed equations and the expressions that can be precomputed into cache buffers,
   tracking the maximum size needed for each buffer type.
4. In a second pass, generate one `SCCNonlinearFunction{iip}` per SCC and one explicit
   function per SCC (a `CacheWriter`, or `Returns(nothing)` when nothing is cached).
5. If any cache buffers are needed, rebuild the parameter object with them.

# Arguments
- `sys`: the system to split.
- `u0map`, `parammap`: forwarded to `process_SciMLProblem` (and to `NonlinearProblem`
  in the single-SCC case).

# Keywords
- `eval_expression`, `eval_module`: forwarded to code generation.
- `cse`: forwarded to `CacheWriter` and `SCCNonlinearFunction`.
- `kwargs...`: forwarded to `process_SciMLProblem`, `SCCNonlinearFunction` and, in the
  single-SCC case, `NonlinearProblem`.

# Returns
`SCCNonlinearProblem(subprobs, explicitfuns, p, true; sys)` where `sys` has its
unknowns and equations reordered to follow the sorted SCCs, or a `NonlinearProblem{iip}`
when there is only one SCC.
"""
function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
        parammap = SciMLBase.NullParameters(); eval_expression = false, eval_module = @__MODULE__,
        cse = true, kwargs...) where {iip}
    if !iscomplete(sys) || get_tearing_state(sys) === nothing
        error("A simplified `NonlinearSystem` is required. Call `structural_simplify` on the system before creating an `SCCNonlinearProblem`.")
    end

    if !is_split(sys)
        error("The system has been simplified with `split = false`. `SCCNonlinearProblem` is not compatible with this system. Pass `split = true` to `structural_simplify` to use `SCCNonlinearProblem`.")
    end

    ts = get_tearing_state(sys)
    var_eq_matching, var_sccs = StructuralTransformations.algebraic_variables_scc(ts)

    if length(var_sccs) == 1
        # NOTE(review): `cse` is not forwarded on this path, so a caller-supplied
        # `cse` has no effect when there is a single SCC - confirm this is intended.
        return NonlinearProblem{iip}(
            sys, u0map, parammap; eval_expression, eval_module, kwargs...)
    end

    condensed_graph = MatchedCondensationGraph(
        DiCMOBiGraph{true}(complete(ts.structure.graph),
            complete(var_eq_matching)),
        var_sccs)
    toporder = topological_sort_by_dfs(condensed_graph)
    var_sccs = var_sccs[toporder]
    eq_sccs = map(Base.Fix1(getindex, var_eq_matching), var_sccs)

    dvs = unknowns(sys)
    eqs = equations(sys)
    obs = observed(sys)

    _, u0, p = process_SciMLProblem(
        EmptySciMLFunction, sys, u0map, parammap; eval_expression, eval_module, kwargs...)

    explicitfuns = []
    nlfuns = []
    prevobsidxs = BlockArray(undef_blocks, Vector{Int}, Int[])
    # Cache buffer types and corresponding sizes. Stored as a pair of arrays instead of a
    # dict to maintain a consistent order of buffers across SCCs
    cachetypes = TypeT[]
    cachesizes = Int[]
    # explicitfun! related information for each SCC
    # We need to compute buffer sizes before doing any codegen
    scc_cachevars = Dict{TypeT, Vector{Any}}[]
    scc_cacheexprs = Dict{TypeT, Vector{Any}}[]
    scc_eqs = Vector{Equation}[]
    scc_obs = Vector{Equation}[]
    # variables solved in previous SCCs
    available_vars = Set()
    for (escc, vscc) in zip(eq_sccs, var_sccs)
        # subset unknowns and equations
        _dvs = dvs[vscc]
        _eqs = eqs[escc]
        # get observed equations required by this SCC
        union!(available_vars, _dvs)
        obsidxs = observed_equations_used_by(sys, _eqs; available_vars)
        # the ones used by previous SCCs can be precomputed into the cache
        setdiff!(obsidxs, prevobsidxs)
        _obs = obs[obsidxs]
        union!(available_vars, getproperty.(_obs, (:lhs,)))

        # get all subexpressions in the RHS which we can precompute in the cache
        # precomputed subexpressions should not contain `banned_vars`
        banned_vars = Set{Any}(vcat(_dvs, getproperty.(_obs, (:lhs,))))
        state = Dict()
        for j in eachindex(_obs)
            _obs[j] = _obs[j].lhs ~ subexpressions_not_involving_vars!(
                _obs[j].rhs, banned_vars, state)
        end
        for j in eachindex(_eqs)
            _eqs[j] = _eqs[j].lhs ~ subexpressions_not_involving_vars!(
                _eqs[j].rhs, banned_vars, state)
        end

        # map from symtype to cached variables and their expressions
        cachevars = Dict{TypeT, Vector{Any}}()
        cacheexprs = Dict{TypeT, Vector{Any}}()
        # observed of previous SCCs are in the cache
        # NOTE: When we get proper CSE, we can substitute these
        # and then use `subexpressions_not_involving_vars!`
        for obsidx in prevobsidxs
            T = symtype(obs[obsidx].lhs)
            buf = get!(() -> Any[], cachevars, T)
            push!(buf, obs[obsidx].lhs)

            buf = get!(() -> Any[], cacheexprs, T)
            push!(buf, obs[obsidx].lhs)
        end

        for (k, v) in state
            k = unwrap(k)
            v = unwrap(v)
            T = symtype(k)
            buf = get!(() -> Any[], cachevars, T)
            push!(buf, v)
            buf = get!(() -> Any[], cacheexprs, T)
            push!(buf, k)
        end

        # update the sizes of cache buffers
        for (T, buf) in cachevars
            idx = findfirst(isequal(T), cachetypes)
            if idx === nothing
                push!(cachetypes, T)
                push!(cachesizes, 0)
                idx = lastindex(cachetypes)
            end
            cachesizes[idx] = max(cachesizes[idx], length(buf))
        end

        push!(scc_cachevars, cachevars)
        push!(scc_cacheexprs, cacheexprs)
        push!(scc_eqs, _eqs)
        push!(scc_obs, _obs)
        blockpush!(prevobsidxs, obsidxs)
    end

    # Second pass: code generation, now that every buffer type and size is known.
    for (i, vscc) in enumerate(var_sccs)
        _dvs = dvs[vscc]
        _eqs = scc_eqs[i]
        _prevobsidxs = reduce(vcat, blocks(prevobsidxs)[1:(i - 1)]; init = Int[])
        _obs = scc_obs[i]
        cachevars = scc_cachevars[i]
        cacheexprs = scc_cacheexprs[i]
        # Unknowns and observed variables belonging to the SCCs before this one.
        solved_vars = [dvs[reduce(vcat, var_sccs[1:(i - 1)]; init = Int[])];
                       getproperty.(
                           reduce(vcat, scc_obs[1:(i - 1)]; init = []), (:lhs,))]
        _prevobsidxs = vcat(_prevobsidxs,
            observed_equations_used_by(
                sys, reduce(vcat, values(cacheexprs); init = []);
                available_vars = solved_vars))
        if isempty(cachevars)
            push!(explicitfuns, Returns(nothing))
        else
            solsyms = getindex.((dvs,), view(var_sccs, 1:(i - 1)))
            push!(explicitfuns,
                CacheWriter(sys, cachetypes, cacheexprs, solsyms, obs[_prevobsidxs];
                    eval_expression, eval_module, cse))
        end

        # One (possibly empty) symbol list per buffer type, in `cachetypes` order.
        cachebufsyms = Tuple(map(cachetypes) do T
            get(cachevars, T, [])
        end)
        f = SCCNonlinearFunction{iip}(
            sys, _eqs, _dvs, _obs, cachebufsyms; eval_expression, eval_module, cse, kwargs...)
        push!(nlfuns, f)
    end

    if !isempty(cachetypes)
        templates = map(cachetypes, cachesizes) do T, n
            # Real refers to `eltype(u0)`
            if T == Real
                T = eltype(u0)
            elseif T <: Array && eltype(T) == Real
                T = Array{eltype(u0), ndims(T)}
            end
            BufferTemplate(T, n)
        end
        p = rebuild_with_caches(p, templates...)
    end

    subprobs = []
    for (f, vscc) in zip(nlfuns, var_sccs)
        prob = NonlinearProblem(f, u0[vscc], p)
        push!(subprobs, prob)
    end

    # Reorder the system's unknowns and equations to follow the sorted SCCs.
    new_dvs = dvs[reduce(vcat, var_sccs)]
    new_eqs = eqs[reduce(vcat, eq_sccs)]
    @set! sys.unknowns = new_dvs
    @set! sys.eqs = new_eqs
    @set! sys.index_cache = subset_unknowns_observed(
        get_index_cache(sys), sys, new_dvs, getproperty.(obs, (:lhs,)))
    return SCCNonlinearProblem(subprobs, explicitfuns, p, true; sys)
end

0 commit comments

Comments
 (0)