Skip to content

Commit be0d176

Browse files
committed
Iterate ssrm
1 parent 67b4472 commit be0d176

File tree

7 files changed

+120
-45
lines changed

7 files changed

+120
-45
lines changed

src/transform/codegen/dae_factory.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,31 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
138138
insert_node_here!(oc_compact,
139139
NewInstruction(Expr(:invoke, daef_ci, oc_sicm, (), in_u_mm, in_u_unassgn, in_du_unassgn, in_alg, out_du_mm, out_eq, Argument(6)), Nothing, line))
140140

141-
# Manually apply mass matrix
141+
# TODO: We should not have to recompute this here
142+
var_eq_matching = matching_for_key(result, key, state.structure)
143+
(slot_assignments, var_assignment, eq_assignment) = assign_slots(state, key, var_eq_matching)
144+
145+
# Manually apply mass matrix and implicit equations between selected states
146+
for v = 1:ndsts(state.structure.graph)
147+
vdiff = state.structure.var_to_diff[v]
148+
vdiff === nothing && continue
149+
150+
if var_eq_matching[v] !== SelectedState() || var_eq_matching[vdiff] !== SelectedState()
151+
# Solved variables were already handled above
152+
continue
153+
end
154+
155+
(kind, slot) = var_assignment[v]
156+
(dkind, dslot) = var_assignment[vdiff]
157+
@assert kind == AssignedDiff
158+
@assert dkind in (AssignedDiff, UnassignedDiff)
159+
160+
v_val = insert_node_here!(oc_compact,
161+
NewInstruction(Expr(:call, Base.getindex, dkind == AssignedDiff ? in_u_mm : in_u_unassgn, dslot), Any, line))
162+
insert_node_here!(oc_compact,
163+
NewInstruction(Expr(:call, Base.setindex!, out_du_mm, v_val, slot), Any, line))
164+
end
165+
142166
bc = insert_node_here!(oc_compact,
143167
NewInstruction(Expr(:call, Base.Broadcast.broadcasted, -, out_du_mm, du_assgn), Any, line))
144168
insert_node_here!(oc_compact,

src/transform/codegen/rhs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ function rhs_finish!(
195195
ir[SSAValue(i)][:inst] = Intrinsics.placeholder_equation
196196
elseif is_solved_variable(stmt)
197197
var = stmt.args[end-1]
198-
vint = invview(result.var_to_diff)[var]
198+
vint = invview(structure.var_to_diff)[var]
199199
if vint !== nothing && key.diff_states !== nothing && (vint in key.diff_states) && !(var in diff_states_in_callee)
200200
handle_contribution!(ir, inst, StateDiff, var_assignment[vint][2], arg_range, stmt.args[end])
201201
else

src/transform/state_selection.jl

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function StateSelection.linear_subsys_adjmat!(state::TransformationState)
1212
eadj = Vector{Int}[]
1313
cadj = Vector{Int}[]
1414
linear_equations = Vector{Int}()
15-
for (i, inc) in enumerate(state.result.total_incidence)
15+
for (i, inc) in enumerate(state.total_incidence)
1616
isa(inc, Const) && continue
1717
any(x->x === nonlinear || !isinteger(x), nonzeros(inc.row)) && continue
1818
(isa(inc.typ, Const) && iszero(inc.typ.val)) || continue
@@ -56,13 +56,24 @@ function StateSelection.eq_derivative!(state::TransformationState, eq)
5656
push!(state.total_incidence, structural_inc_ddt(s.var_to_diff, nothing, nothing, state.total_incidence[eq]))
5757
end
5858

59+
function StateSelection.var_derivative!(state::TransformationState, var)
60+
return StateSelection.var_derivative_graph!(state.structure, var)
61+
end
62+
5963
function baseeq(result, structure, eq)
6064
while eq > length(result.eqkinds)
6165
eq = invview(structure.eq_to_diff)[eq]
6266
end
6367
return eq
6468
end
6569

70+
function basevar(result, structure, var)
71+
while var > length(result.varkinds)
72+
var = invview(structure.var_to_diff)[var]
73+
end
74+
return var
75+
end
76+
6677
function eqkind(result, structure, eq)
6778
return result.eqkinds[baseeq(result, structure, eq)]
6879
end
@@ -72,7 +83,7 @@ function eqclassification(result, structure, eq)
7283
return result.eqclassification[baseeq(result, structure, eq)]
7384
end
7485

75-
function structural_transformation!(state::TransformationState)
86+
function ssrm!(state::TransformationState)
7687
ils = StateSelection.structural_singularity_removal!(state)
7788

7889
s = state.structure
@@ -82,11 +93,35 @@ function structural_transformation!(state::TransformationState)
8293
end
8394
state.total_incidence[e] = Incidence(Const(0.), IncidenceVector(MAX_EQS, map(x->x+1, ils.row_cols[ei]), Union{Float64, NonLinear}[Float64(x) for x in ils.row_vals[ei]]))
8495
end
96+
end
8597

86-
var_eq_matching = StateSelection.pantelides!(state;
87-
varfilter = var->state.result.varkinds[var] == Intrinsics.Continuous && !(var <= state.result.nexternalvars),
88-
eqfilter = eq->eqkind(state, eq) == Intrinsics.Always)
89-
return StateSelection.complete(var_eq_matching, nsrcs(state.structure.graph))
98+
function varkind(result, structure, var)
99+
while var > length(result.varkinds)
100+
var = invview(structure.var_to_diff)[var]
101+
end
102+
return result.varkinds[var]
103+
end
104+
varkind(state, var) = varkind(state.result, state.structure, var)
105+
106+
function structural_transformation!(state::TransformationState)
107+
first = true
108+
# This loop is required to handle situations where additional structural signularities arise in the differentiated
109+
# equations. We could probably do lot better here by not constantly recomputing the datastructures.
110+
while true
111+
neq_before = length(state.structure.eq_to_diff)
112+
var_eq_matching = StateSelection.pantelides!(state;
113+
varfilter = var->varkind(state, var) == Intrinsics.Continuous && !(var <= state.result.nexternalvars),
114+
eqfilter = eq->eqkind(state, eq) == Intrinsics.Always)
115+
116+
differentiated_any = neq_before != length(state.structure.eq_to_diff)
117+
if differentiated_any || first
118+
ssrm!(state)
119+
first = false
120+
continue
121+
end
122+
123+
return StateSelection.complete(var_eq_matching, nsrcs(state.structure.graph))
124+
end
90125
end
91126

92127
using StateSelection: Unassigned, SelectedState, unassigned
@@ -99,14 +134,6 @@ function top_level_state_selection!(tstate)
99134

100135
StateSelection.complete!(structure)
101136

102-
diffvars = result.varkinds .== Intrinsics.Continuous
103-
for param in param_vars
104-
diffvars[param] = false
105-
end
106-
107-
@assert length(diffvars) == ndsts(structure.graph) == length(structure.var_to_diff)
108-
varwhitelist = StateSelection.computed_highest_diff_variables(structure, diffvars)
109-
110137
## Part 1: Perform the selection of differential states and subsequent tearing of the
111138
# non-linear problem at every time step.
112139

@@ -134,16 +161,18 @@ function top_level_state_selection!(tstate)
134161
diff_key = TornCacheKey(diff_vars, alg_vars, param_vars, explicit_eqs, Vector{Pair{BitSet, BitSet}}())
135162
@assert matching_for_key(result, diff_key, structure) == var_eq_matching
136163

164+
varfilter(var) = varkind(result, structure, var) == Intrinsics.Continuous && !(var <= result.nexternalvars)
165+
137166
## Part 2: Perform the selection of differential states and subsequent tearing of the
138167
# non-linear problem at every time step.
139168
init_var_eq_matching = StateSelection.complete(StateSelection.maximal_matching(structure.graph;
140-
dstfilter = var->diffvars[var], srcfilter = eq->eqkind(result, structure, eq) in (Intrinsics.Always, Intrinsics.Initial)), nsrcs(structure.graph))
169+
dstfilter = varfilter, srcfilter = eq->eqkind(result, structure, eq) in (Intrinsics.Always, Intrinsics.Initial)), nsrcs(structure.graph))
141170
init_var_eq_matching = StateSelection.pss_graph_modia!(structure, init_var_eq_matching)
142171

143172
init_state_vars = BitSet()
144173
init_explicit_eqs = BitSet()
145174
for (v, match) in enumerate(init_var_eq_matching)
146-
diffvars[v] || continue
175+
varfilter(v) || continue
147176
if match === unassigned
148177
push!(init_state_vars, v)
149178
end

src/transform/tearing/schedule.jl

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -235,23 +235,27 @@ function compute_eq_schedule(key::TornCacheKey, result, mss::StateSelection.Matc
235235
frontier = BitSet()
236236

237237
isempty(var_schedule) && (var_schedule = Pair{BitSet, BitSet}[available=>BitSet()])
238+
vargraph = DiCMOBiGraph{true}(structure.graph, var_eq_matching)
239+
240+
function may_enqueue_frontier_var!(var)
241+
isa(var_eq_matching[var], Int) || return
242+
(var in available) && return # Already processed
243+
(var in frontier) && return # Already queued
244+
# Check if this neighbor is ready
245+
if all(inn->(inn in available), inneighbors(vargraph, var))
246+
push!(frontier, var)
247+
end
248+
end
238249

239250
for sched in var_schedule
240251
eq_order = Union{Int, SSAValue}[]
241252
push!(eq_orders, eq_order)
242-
vargraph = DiCMOBiGraph{true}(structure.graph, var_eq_matching)
243253

244254
(in_vars, _) = sched
245255
union!(available, in_vars)
246256

247257
for neighbor in 1:ndsts(structure.graph)
248-
isa(var_eq_matching[neighbor], Int) || continue
249-
(neighbor in available) && continue # Already processed
250-
(neighbor in frontier) && continue # Already queued
251-
# Check if this neighbor is ready
252-
if all(inn->(inn in available), inneighbors(vargraph, neighbor))
253-
push!(frontier, neighbor)
254-
end
258+
may_enqueue_frontier_var!(neighbor)
255259
end
256260

257261
new_available = BitSet()
@@ -347,6 +351,11 @@ function compute_eq_schedule(key::TornCacheKey, result, mss::StateSelection.Matc
347351

348352
if !isempty(new_available)
349353
union!(available, new_available)
354+
for var in new_available
355+
for neighbor in outneighbors(vargraph, var)
356+
may_enqueue_frontier_var!(neighbor)
357+
end
358+
end
350359
setdiff!(frontier, new_available)
351360
empty!(new_available)
352361
continue
@@ -443,7 +452,10 @@ function assign_slots(state::TransformationState, key::TornCacheKey, var_eq_matc
443452
if kind == AlgebraicDerivative
444453
var_assignment[i] = kind => slot
445454
elseif kind == AssignedDiff && var_eq_matching !== nothing
446-
eq_assignment[var_eq_matching[state.structure.var_to_diff[i]]] = StateDiff => slot
455+
eq = var_eq_matching[state.structure.var_to_diff[i]]
456+
if isa(eq, Int)
457+
eq_assignment[eq] = StateDiff => slot
458+
end
447459
end
448460
end
449461

@@ -475,7 +487,7 @@ function matching_for_key(result::DAEIPOResult, key::TornCacheKey, structure = m
475487

476488
allow_init_eqs = key.diff_states === nothing
477489

478-
may_use_var(var) = var > result.nexternalvars && (diff_states === nothing || !(var in diff_states)) && !(var in alg_states) && result.varkinds[var] == Intrinsics.Continuous
490+
may_use_var(var) = var > result.nexternalvars && (diff_states === nothing || !(var in diff_states)) && !(var in alg_states) && varkind(result, structure, var) == Intrinsics.Continuous
479491
may_use_eq(eq) = !(eq in explicit_eqs) && eqclassification(result, structure, eq) != External && eqkind(result, structure, eq) in (allow_init_eqs ? (Intrinsics.Initial, Intrinsics.Always) : (Intrinsics.Always,))
480492

481493
# Max match is the (unique) tearing result given the choice of states
@@ -825,7 +837,7 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To
825837
ir_sicm = Compiler.finish(compact)
826838
end
827839

828-
var_sols = Vector{Any}(undef, length(result.var_to_diff))
840+
var_sols = Vector{Any}(undef, length(structure.var_to_diff))
829841

830842
for var in key.param_vars
831843
var_sols[var] = 0.0
@@ -847,7 +859,7 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To
847859
end
848860

849861
function insert_solved_var_here!(compact1, var, curval, line)
850-
if result.varclassification[var] != Owned
862+
if result.varclassification[basevar(result, structure, var)] != Owned
851863
return
852864
end
853865
insert_node_here!(compact1, NewInstruction(Expr(:call, solved_variable, var, curval), Nothing, line))

test/basic.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,16 @@ sol = solve(DAECProblem(pantelides, (1,) .=> 0.), DFBDF(autodiff=false))
4747
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[:, 1], sol.t))
4848

4949
#= Structural Singularity Removal =#
50-
function ssm()
50+
function ssrm()
5151
a = continuous()
5252
b = continuous()
5353
abdot = ddt(a +ᵢ b)
5454
always!(a -ᵢ abdot)
5555
always!(b -ᵢ abdot)
5656
end
5757

58-
ssm()
59-
sol = solve(DAECProblem(ssm, (1,) .=> 1.), DFBDF(autodiff=false))
58+
ssrm()
59+
sol = solve(DAECProblem(ssrm, (1,) .=> 1.), DFBDF(autodiff=false))
6060
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[:, 1], exp.(0.5sol.t)))
6161

6262
#= Pantelides from init =#

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
include("basic.jl")
22
include("ipo.jl")
3-
include("ssm.jl")
3+
include("ssrm.jl")

test/ssm.jl renamed to test/ssrm.jl

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,22 @@ using OrdinaryDiffEq
1010
const += Core.Intrinsics.add_float
1111
const -= Core.Intrinsics.sub_float
1212

13-
function ssm2()
13+
function ssrm2()
1414
a = continuous()
1515
b = continuous()
1616
abdot = ddt(a +ᵢ b)
1717
always!(a -ᵢ abdot)
1818
always!(b +ᵢ abdot)
1919
end
2020

21-
ssm2()
22-
23-
# TODO: Currently broken
24-
# solve(DAECProblem(ssm2, (1,) .=> 1.), DFBDF(autodiff=false))
21+
ssrm2()
22+
# Doesn't DFBDF doens't like the degenerate case
23+
@test isempty(DAECompiler.factory(ssrm2)[2])
24+
@test_broken solve(DAECProblem(ssrm2, ()), DFBDF())
2525

2626
# This is example (7.30) from Taihei Oki "Computing Valuations of Determinants via Combinatorial Optimization: Applications to Differential Equations".
2727
# The system is index 4 and requires iterating pantelides/ssm.
28-
function ssm4()
28+
function ssrm4()
2929
x₁ = continuous()
3030
x₂ = continuous()
3131
x₃ = continuous()
@@ -37,14 +37,24 @@ function ssm4()
3737
ẍ₁ = ddt(ẋ₁)
3838
ẍ₂ = ddt(ẋ₂)
3939
ẍ₃ = ddt(ẋ₃)
40-
always!(ẍ₁+ẍ₂-(ẋ₁+ẋ₂)+x₄)
41-
always!(ẍ₁+ẍ₂+x₃)
40+
always!((ẍ₁+ẍ₂)-(ẋ₁+ẋ₂)+x₄)
41+
always!((ẍ₁+ẍ₂)+x₃)
4242
always!(x₂+ẍ₃+ẋ₄)
4343
always!(x₃+ẋ₄)
4444
end
4545

46-
ssm4()
46+
ssrm4()
47+
# We expect state selection here to pick (x₁, x₄, ẋ₁)
48+
# The system simplifies to ẍ₁ = ẋ₄ = ẋ₁ - x₄
49+
init = (1.,0.,1.)
50+
function analytic(init, t)
51+
c = init[3] - init[2]
52+
ẋ₁ = init[3] + c*t
53+
x₄ = init[2] + c*t
54+
x₁ = init[1] + init[3]*t + 1/2*c*t^2
55+
return (x₁, x₄, ẋ₁)
56+
end
57+
sol = solve(DAECProblem(ssrm4, (1,2,3) .=> init), DFBDF(autodiff=false))
58+
@test isapprox(sol[:, :]', mapreduce(t->[analytic(init, t)...]', vcat, sol.t), atol=1e-4)
4759

48-
# TODO: Currently broken
49-
# solve(DAECProblem(ssm4, (1,) .=> 1.), DFBDF(autodiff=false))
5060
end

0 commit comments

Comments
 (0)