Skip to content

Commit 602adce

Browse files
committed
Port two ipo tests from v1
1 parent d85b916 commit 602adce

File tree

4 files changed

+78
-25
lines changed

4 files changed

+78
-25
lines changed

Manifest.toml

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/transform/state_selection.jl

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,25 @@ function structural_transformation!(state::TransformationState)
119119
first = false
120120
continue
121121
end
122-
122+
123123
return StateSelection.complete(var_eq_matching, nsrcs(state.structure.graph))
124124
end
125125
end
126126

127127
using StateSelection: Unassigned, SelectedState, unassigned
128+
129+
struct StateInvariant; end
130+
StateSelection.BipartiteGraphs.overview_label(::Type{StateInvariant}) = ('P', "State Invariant / Parameter", :red)
131+
132+
struct InOut
133+
ordinal::Int
134+
end
135+
StateSelection.BipartiteGraphs.overview_label(::Type{InOut}) = ('#', "IPO in var / out eq", :green)
136+
StateSelection.BipartiteGraphs.overview_label(io::InOut) = (string(io.ordinal), "IPO in var / out eq", :green)
137+
138+
const IPOMatches = Union{Unassigned, SelectedState, StateInvariant, InOut}
139+
const IPOMatching = StateSelection.Matching{IPOMatches}
140+
128141
function top_level_state_selection!(tstate)
129142
(; result, structure) = tstate
130143
# For the top-level problem, all external vars are state-invariant, and we do no other fissioning
@@ -135,16 +148,20 @@ function top_level_state_selection!(tstate)
135148
StateSelection.complete!(structure)
136149

137150
## Part 1: Perform the selection of differential states and subsequent tearing of the
138-
# non-linear problem at every time step.
151+
# non-linear problem at every time step.
139152

140-
var_eq_matching = StateSelection.partial_state_selection_graph!(structure, highest_diff_max_match)
153+
var_eq_matching = convert(IPOMatching, StateSelection.partial_state_selection_graph!(structure, highest_diff_max_match))
141154

142155
diff_vars = BitSet()
143156
alg_vars = BitSet()
144157
explicit_eqs = BitSet()
145158

146159
for (v, match) in enumerate(var_eq_matching)
147-
v in param_vars && continue
160+
if v in param_vars
161+
@assert match === unassigned
162+
var_eq_matching[v] = StateInvariant()
163+
continue
164+
end
148165
if match === SelectedState()
149166
push!(diff_vars, v)
150167
elseif match === unassigned
@@ -164,14 +181,19 @@ function top_level_state_selection!(tstate)
164181
varfilter(var) = varkind(result, structure, var) == Intrinsics.Continuous && !(var <= result.nexternalvars)
165182

166183
## Part 2: Perform the selection of differential states and subsequent tearing of the
167-
# non-linear problem at every time step.
168-
init_var_eq_matching = StateSelection.complete(StateSelection.maximal_matching(structure.graph;
184+
# non-linear problem at every time step.
185+
init_var_eq_matching = StateSelection.complete(StateSelection.maximal_matching(structure.graph, IPOMatches;
169186
dstfilter = varfilter, srcfilter = eq->eqkind(result, structure, eq) in (Intrinsics.Always, Intrinsics.Initial)), nsrcs(structure.graph))
170-
init_var_eq_matching = StateSelection.pss_graph_modia!(structure, init_var_eq_matching)
187+
init_var_eq_matching = convert(IPOMatching, StateSelection.pss_graph_modia!(structure, init_var_eq_matching))
171188

172189
init_state_vars = BitSet()
173190
init_explicit_eqs = BitSet()
174191
for (v, match) in enumerate(init_var_eq_matching)
192+
if v in param_vars
193+
@assert match === unassigned
194+
init_var_eq_matching[v] = StateInvariant()
195+
continue
196+
end
175197
varfilter(v) || continue
176198
if match === unassigned
177199
push!(init_state_vars, v)

src/transform/tearing/schedule.jl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,6 @@ end
5555

5656
Base.IteratorSize(::Type{Compiler.UseRefIterator}) = Base.SizeUnknown()
5757

58-
struct StateInvariant; end
59-
StateSelection.BipartiteGraphs.overview_label(::Type{StateInvariant}) = ('P', "State Invariant / Parameter", :red)
60-
61-
struct InOut
62-
ordinal::Int
63-
end
64-
StateSelection.BipartiteGraphs.overview_label(::Type{InOut}) = ('#', "IPO in var / out eq", :green)
65-
StateSelection.BipartiteGraphs.overview_label(io::InOut) = (string(io.ordinal), "IPO in var / out eq", :green)
66-
6758
function schedule_incidence!(compact, var_eq_matching, curval, ::Type, var, line; vars=nothing, schedule_missing_var! = nothing)
6859
# This just needs the linear part, which is `0` in `Type`
6960
return (curval, nothing)
@@ -511,7 +502,7 @@ function matching_for_key(result::DAEIPOResult, key::TornCacheKey, structure = m
511502
may_use_eq(eq) = !(eq in explicit_eqs) && eqclassification(result, structure, eq) != External && eqkind(result, structure, eq) in (allow_init_eqs ? (Intrinsics.Initial, Intrinsics.Always) : (Intrinsics.Always,))
512503

513504
# Max match is the (unique) tearing result given the choice of states
514-
var_eq_matching = StateSelection.complete(StateSelection.maximal_matching(structure.graph, Union{Unassigned, SelectedState, StateInvariant, InOut};
505+
var_eq_matching = StateSelection.complete(StateSelection.maximal_matching(structure.graph, IPOMatches;
515506
dstfilter = may_use_var, srcfilter = may_use_eq), nsrcs(structure.graph))
516507

517508
if diff_states !== nothing

test/ipo.jl

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,51 @@ function twocall!()
2626
end
2727

2828
twocall!()
29-
sol = solve(DAECProblem(twocall!, (1, 2) .=> 1.), IDA())
30-
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol[1, :], exp.(sol.t)))
31-
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol[2, :], exp.(sol.t)))
32-
sol = solve(ODECProblem(twocall!, (1, 2) .=> 1.), Rodas5(autodiff=false))
33-
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol[1, :], exp.(sol.t)))
34-
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol[2, :], exp.(sol.t)))
29+
dae_sol = solve(DAECProblem(twocall!, (1, 2) .=> 1.), IDA())
30+
ode_sol = solve(ODECProblem(twocall!, (1, 2) .=> 1.), Rodas5(autodiff=false))
31+
for (sol, i) in Iterators.product((dae_sol, ode_sol), 1:2)
32+
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol[i, :], exp.(sol.t)))
33+
end
34+
35+
#= + NonLinear =#
36+
@noinline function sin!()
37+
x = continuous()
38+
always!(ddt(x) - sin(x))
39+
end
40+
function sin2!()
41+
sin!(); sin!();
42+
return nothing
43+
end
44+
dae_sol = solve(DAECProblem(sin2!, (1, 2) .=> 1.), IDA())
45+
ode_sol = solve(ODECProblem(sin2!, (1, 2) .=> 1.), Rodas5(autodiff=false))
46+
for (sol, i) in Iterators.product((dae_sol, ode_sol), 1:2)
47+
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol[i, :], 2*acot.(exp.(-sol.t).*cot(1/2))))
48+
end
49+
50+
#= + SICM =#
51+
struct sicm!
52+
arg::Float64
53+
end
54+
55+
@noinline function (this::sicm!)()
56+
x = continuous()
57+
always!(ddt(x) - this.arg)
58+
end
59+
60+
struct sicm2!
61+
a::Float64
62+
b::Float64
63+
end
64+
65+
function (this::sicm2!)()
66+
sicm!(this.a)(); sicm!(this.b)();
67+
return nothing
68+
end
69+
dae_sol = solve(DAECProblem(sicm2!(1., 1.), (1, 2) .=> 1.), IDA())
70+
ode_sol = solve(ODECProblem(sicm2!(1., 1.), (1, 2) .=> 1.), Rodas5(autodiff=false))
71+
for (sol, i) in Iterators.product((dae_sol, ode_sol), 1:2)
72+
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol[i, :], 1. .+ sol.t))
73+
end
74+
3575

3676
end

0 commit comments

Comments
 (0)