Skip to content

Commit 036f74f

Browse files
committed
WIP
1 parent 968185b commit 036f74f

File tree

4 files changed

+51
-16
lines changed

4 files changed

+51
-16
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ using SparseArrays
3838
using NonlinearSolve
3939

4040
export tearing, dae_index_lowering, check_consistency
41+
export tearing_assignments, tearing_substitution
4142
export build_torn_function, build_observed_function, ODAEProblem
4243
export sorted_incidence_matrix
4344

src/structural_transformation/codegen.jl

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_s
9696
end
9797

9898
"""
99-
exprs = gen_nlsolve(eqs::Vector{Equation}, vars::Vector, u0map::Dict; checkbounds = true)
99+
exprs = gen_nlsolve(eqs::Vector{Equation}, vars::Vector, u0map::Dict; checkbounds = true, assignments)
100100
101101
Generate `SymbolicUtils` expressions for a root-finding function based on `eqs`,
102102
as well as a call to the root-finding solver.
@@ -112,13 +112,24 @@ exprs = [fname = f, numerical_nlsolve(fname, ...)]
112112
- `u0map`: A `Dict` which maps variables in `eqs` to values, e.g., `defaults(sys)` if `eqs = equations(sys)`.
113113
- `checkbounds`: Apply bounds checking in the generated code.
114114
"""
115-
function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true)
115+
function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true, assignments)
116116
isempty(vars) && throw(ArgumentError("vars may not be empty"))
117117
length(eqs) == length(vars) || throw(ArgumentError("vars must be of the same length as the number of equations to find the roots of"))
118118
rhss = map(x->x.rhs, eqs)
119119
# We use `vars` instead of `graph` to capture parameters, too.
120-
allvars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))
121-
params = setdiff(allvars, vars) # these are not the subject of the root finding
120+
allvars = Set(Iterators.flatten(ModelingToolkit.vars(r) for r in rhss))
121+
vars_set = Set(vars)
122+
params = setdiff(allvars, vars_set) # these are not the subject of the root finding
123+
#needed_assignments = filter(a->a.lhs in params, assignments)
124+
needed_assignments = assignments[1:findlast(a->a.lhs in params, assignments)]
125+
params = setdiff(params, [a.lhs for a in needed_assignments])
126+
@show needed_assignments, params
127+
for a in needed_assignments
128+
ModelingToolkit.vars!(params, a.rhs)
129+
end
130+
params = setdiff(params, vars_set) # these are not the subject of the root finding
131+
@show params
132+
# inductor1₊v, inductor2₊v
122133

123134
# splatting to tighten the type
124135
u0 = []
@@ -141,10 +152,10 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true)
141152
f = Func(
142153
[
143154
DestructuredArgs(vars, inbounds=!checkbounds)
144-
DestructuredArgs(params, inbounds=!checkbounds)
155+
DestructuredArgs(collect(params), inbounds=!checkbounds)
145156
],
146157
[],
147-
isscalar ? rhss[1] : MakeArray(rhss, SVector)
158+
Let(needed_assignments, isscalar ? rhss[1] : MakeArray(rhss, SVector))
148159
) |> SymbolicUtils.Code.toexpr
149160

150161
# solver call contains code to call the root-finding solver on the function f
@@ -193,18 +204,21 @@ function build_torn_function(
193204

194205
states_idxs = collect(diffvars_range(s))
195206
mass_matrix_diag = ones(length(states_idxs))
207+
208+
assignments, bf_states = tearing_assignments(sys)
196209
torn_expr = []
210+
197211
defs = defaults(sys)
198212
nlsolve_scc_idxs = Int[]
199213

200214
needs_extending = false
201-
for (i, scc) in enumerate(var_sccs)
215+
@views for (i, scc) in enumerate(var_sccs)
202216
#torn_vars = [s.fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
203217
torn_vars_idxs = Int[var for var in scc if var_eq_matching[var] !== unassigned]
204218
torn_eqs_idxs = [var_eq_matching[var] for var in torn_vars_idxs]
205219
isempty(torn_eqs_idxs) && continue
206220
if length(torn_eqs_idxs) <= max_inlining_size
207-
append!(torn_expr, gen_nlsolve(eqs[torn_eqs_idxs], s.fullvars[torn_vars_idxs], defs, checkbounds=checkbounds))
221+
append!(torn_expr, gen_nlsolve(eqs[torn_eqs_idxs], s.fullvars[torn_vars_idxs], defs, checkbounds=checkbounds, assignments=assignments))
208222
push!(nlsolve_scc_idxs, i)
209223
else
210224
needs_extending = true
@@ -226,6 +240,7 @@ function build_torn_function(
226240

227241
states = s.fullvars[states_idxs]
228242
syms = map(Symbol, states_idxs)
243+
229244
pre = get_postprocess_fbody(sys)
230245

231246
expr = SymbolicUtils.Code.toexpr(
@@ -238,18 +253,22 @@ function build_torn_function(
238253
],
239254
[],
240255
pre(Let(
241-
torn_expr,
256+
[torn_expr; assignments],
242257
funbody
243258
))
244-
)
259+
),
260+
bf_states
245261
)
246262
if expression
247263
expr, states
248264
else
249-
observedfun = let sys = sys, dict = Dict()
265+
observedfun = let sys=sys, dict=Dict(), assignments=assignments, bf_states=bf_states
250266
function generated_observed(obsvar, u, p, t)
251267
obs = get!(dict, value(obsvar)) do
252-
build_observed_function(sys, obsvar, var_eq_matching, var_sccs, checkbounds=checkbounds)
268+
build_observed_function(sys, obsvar, var_eq_matching, var_sccs,
269+
checkbounds=checkbounds,
270+
assignments=assignments,
271+
bf_states=bf_states)
253272
end
254273
obs(u, p, t)
255274
end
@@ -286,7 +305,9 @@ function build_observed_function(
286305
sys, ts, var_eq_matching, var_sccs;
287306
expression=false,
288307
output_type=Array,
289-
checkbounds=true
308+
checkbounds=true,
309+
assignments,
310+
bf_states,
290311
)
291312

292313
if (isscalar = !(ts isa AbstractVector))
@@ -348,7 +369,7 @@ function build_observed_function(
348369
end
349370
pre = get_postprocess_fbody(sys)
350371

351-
ex = Func(
372+
ex = Code.toexpr(Func(
352373
[
353374
DestructuredArgs(diffvars, inbounds=!checkbounds)
354375
DestructuredArgs(parameters(sys), inbounds=!checkbounds)
@@ -357,13 +378,14 @@ function build_observed_function(
357378
[],
358379
pre(Let(
359380
[
381+
assignments
360382
collect(Iterators.flatten(solves))
361383
map(eq -> eq.lhseq.rhs, obs[1:maxidx])
362384
subs
363385
],
364386
isscalar ? ts[1] : MakeArray(ts, output_type)
365387
))
366-
) |> Code.toexpr
388+
), bf_states)
367389

368390
expression ? ex : @RuntimeGeneratedFunction(ex)
369391
end

src/structural_transformation/symbolics_tearing.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,18 @@ function tearing_substitution(sys::AbstractSystem; simplify=false)
5858
@set! sys.substitutions = nothing
5959
end
6060

61+
function tearing_assignments(sys::AbstractSystem)
62+
if empty_substitutions(sys)
63+
assignments = []
64+
bf_states = Code.LazyState()
65+
else
66+
subs = get_substitutions(sys)
67+
assignments = [Assignment(eq.lhs, eq.rhs) for eq in subs]
68+
bf_states = Code.NameState(Dict(eq.lhs => Symbol(eq.lhs) for eq in subs))
69+
end
70+
return assignments, bf_states
71+
end
72+
6173
function solve_equation(eq, var, simplify)
6274
rhs = value(solve_for(eq, var; simplify=simplify, check=false))
6375
occursin(var, rhs) && error("solving $rhs for [$var] failed")

test/components.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Test
22
using ModelingToolkit, OrdinaryDiffEq
33
using ModelingToolkit.BipartiteGraphs
4-
using ModelingToolkit.StructuralTransformations: tearing_substitution
4+
using ModelingToolkit.StructuralTransformations
55

66
function check_contract(sys)
77
sys = tearing_substitution(sys)

0 commit comments

Comments
 (0)