Skip to content

Commit 11cb2e6

Browse files
committed
Optimize no. of equations when converting to semi-implicit ODE
1 parent 07644d8 commit 11cb2e6

File tree

2 files changed

+30
-56
lines changed

2 files changed

+30
-56
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 28 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -168,23 +168,17 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
168168
removed_eqs = Int[]
169169
removed_vars = Int[]
170170
diff_to_var = invview(var_to_diff)
171-
var2idx = Dict(reverse(en) for en in enumerate(fullvars))
172171
for var in 1:length(fullvars)
173172
dv = var_to_diff[var]
174173
dv === nothing && continue
175174
if var_eq_matching[var] !== SelectedState()
176175
dd = fullvars[dv]
177-
# TODO: figure this out structurally
176+
# TODO: check if observed has it
178177
v_t = diff2term(unwrap(dd))
179-
v_t_idx = get(var2idx, v_t, nothing)
180-
if v_t_idx isa Int
181-
substitute_vars!(graph, ((dv => v_t_idx),), idx_buffer, sub_callback!)
182-
else
183-
for eq in 𝑑neighbors(graph, dv)
184-
neweqs[eq] = substitute(neweqs[eq], fullvars[dv] => v_t)
185-
end
186-
fullvars[dv] = v_t
178+
for eq in 𝑑neighbors(graph, dv)
179+
neweqs[eq] = substitute(neweqs[eq], fullvars[dv] => v_t)
187180
end
181+
fullvars[dv] = v_t
188182
# update the structural information
189183
diff_to_var[dv] = nothing
190184
end
@@ -251,7 +245,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
251245
# As a final note, in all the above cases where we need to introduce new
252246
# variables and equations, don't add them when they already exist.
253247

254-
var_to_idx = Dict{Any, Int}(reverse(en) for en in enumerate(fullvars))
255248
if ModelingToolkit.has_iv(state.sys)
256249
iv = get_iv(state.sys)
257250
D = Differential(iv)
@@ -304,8 +297,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
304297
# D(x) ~ x_t
305298
ogidx = var_to_diff[ogidx]
306299

307-
has_x_t = false
308-
x_t_idx::Union{Nothing, Int} = nothing
309300
dx_idx = var_to_diff[xidx]
310301
if dx_idx === nothing
311302
dx = D(x)
@@ -319,50 +310,28 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
319310
var_to_diff[xidx] = dx_idx
320311
else
321312
dx = fullvars[dx_idx]
322-
var_eq_matching[dx_idx] = unassigned
323-
324-
for eq in 𝑑neighbors(graph, dx_idx)
325-
vs = 𝑠neighbors(graph, eq)
326-
length(vs) == 2 || continue
327-
maybe_x_t_idx = vs[1] == dx_idx ? vs[2] : vs[1]
328-
maybe_x_t = fullvars[maybe_x_t_idx]
329-
difference = (neweqs[eq].lhs - neweqs[eq].rhs) - (dx - maybe_x_t)
330-
# if `eq` is in the form of `D(x) ~ x_t`
331-
if ModelingToolkit._iszero(difference)
332-
x_t_idx = maybe_x_t_idx
333-
x_t = maybe_x_t
334-
eq_idx = eq
335-
push!(order_lowering_eqs, eq_idx)
336-
has_x_t = true
337-
break
338-
end
339-
end
340313
end
341314

342-
if x_t_idx === nothing
343-
x_t = ModelingToolkit.lower_varname(ogx, iv, o)
344-
push!(fullvars, x_t)
345-
x_t_idx = add_vertex!(var_to_diff)
346-
add_vertex!(graph, DST)
347-
add_vertex!(solvable_graph, DST)
348-
@assert x_t_idx == ndsts(graph) == length(fullvars)
349-
push!(var_eq_matching, unassigned)
350-
end
351-
x_t_idx::Int
352-
353-
if !has_x_t
354-
push!(neweqs, dx ~ x_t)
355-
eq_idx = add_vertex!(eq_to_diff)
356-
push!(order_lowering_eqs, eq_idx)
357-
add_vertex!(graph, SRC)
358-
add_vertex!(solvable_graph, SRC)
359-
@assert eq_idx == nsrcs(graph) == length(neweqs)
360-
361-
add_edge!(solvable_graph, eq_idx, x_t_idx)
362-
add_edge!(solvable_graph, eq_idx, dx_idx)
363-
add_edge!(graph, eq_idx, x_t_idx)
364-
add_edge!(graph, eq_idx, dx_idx)
365-
end
315+
# TODO: check if it's already in observed
316+
x_t = ModelingToolkit.lower_varname(ogx, iv, o)
317+
push!(fullvars, x_t)
318+
x_t_idx = add_vertex!(var_to_diff)
319+
add_vertex!(graph, DST)
320+
add_vertex!(solvable_graph, DST)
321+
@assert x_t_idx == ndsts(graph) == length(fullvars)
322+
push!(var_eq_matching, unassigned)
323+
324+
push!(neweqs, dx ~ x_t)
325+
eq_idx = add_vertex!(eq_to_diff)
326+
push!(order_lowering_eqs, eq_idx)
327+
add_vertex!(graph, SRC)
328+
add_vertex!(solvable_graph, SRC)
329+
@assert eq_idx == nsrcs(graph) == length(neweqs)
330+
331+
add_edge!(solvable_graph, eq_idx, x_t_idx)
332+
add_edge!(solvable_graph, eq_idx, dx_idx)
333+
add_edge!(graph, eq_idx, x_t_idx)
334+
add_edge!(graph, eq_idx, dx_idx)
366335
# We use this info to substitute all `D(D(x))` or `D(x_t)` except
367336
# the `D(D(x)) ~ x_tt` equation to `x_tt`.
368337
# D(D(x)) D(x_t) x_tt
@@ -382,6 +351,10 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
382351
# substituted to `x_tt`.
383352
for idx in (ogidx, dx_idx)
384353
subidx = ((idx => x_t_idx),)
354+
# This handles case 2.2
355+
if var_eq_matching[idx] isa Int
356+
var_eq_matching[x_t_idx] = var_eq_matching[idx]
357+
end
385358
substitute_vars!(graph, subidx, idx_buffer, sub_callback!;
386359
exclude = order_lowering_eqs)
387360
substitute_vars!(solvable_graph, subidx, idx_buffer;

test/structural_transformation/index_reduction.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,8 @@ for sys in [
157157
g => 9.8,
158158
]
159159

160-
prob_auto = ODEProblem(sys, u0, (0.0, 1.0), p)
160+
prob_auto = ODEProblem(sys, u0, (0.0, 0.5), p)
161161
sol = solve(prob_auto, FBDF())
162+
@test sol.retcode === :Success
162163
@test norm(sol[x] .^ 2 + sol[y] .^ 2 .- 1) < 1e-2
163164
end

0 commit comments

Comments
 (0)