Skip to content

Commit 4402b42

Browse files
committed
Correctly compute state/equation indices for IPO
This includes a refactor that considers mapping to states, instead of individual mappings into u/du vectors
1 parent a81204f commit 4402b42

File tree

7 files changed

+98
-66
lines changed

7 files changed

+98
-66
lines changed

src/transform/codegen/dae_factory.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Unio
147147
argt = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64}
148148
sicm = ()
149149
if settings.skip_optimizations
150-
daef_ci = rhs_finish_noopt!(state, ci, key, world, settings, 1)
150+
daef_ci = rhs_finish_noopt!(state, ci, key, world, settings)
151151
oc = sciml_to_internal_abi_noopt!(copy(ci.inferred.ir), state, daef_ci, settings)
152152
else
153153
# TODO: We should not have to recompute this here

src/transform/codegen/rhs.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
const VectorViewType = SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int}}, true}
2-
const VectorIntViewType = SubArray{Int, 1, Vector{Int}, Tuple{UnitRange{Int}}, true}
32

43
"""
54
struct RHSSpec

src/transform/common.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ function replace_call!(ir::Union{IRCode,IncrementalCompact}, idx::SSAValue, @nos
116116
ir[idx][:type] = Any
117117
ir[idx][:info] = Compiler.NoCallInfo()
118118
ir[idx][:flag] |= Compiler.IR_FLAG_REFINED
119-
return new_call
119+
return idx
120120
end
121121

122122
function maybe_insert_debuginfo!(compact::IncrementalCompact, settings::Settings, source::LineNumberNode, previous = nothing, i = compact.result_idx)
@@ -129,6 +129,7 @@ function maybe_insert_debuginfo!(debuginfo::DebugInfoStream, settings::Settings,
129129
end
130130

131131
function insert_debuginfo!(debuginfo::DebugInfoStream, i::Integer, source::LineNumberNode, previous)
132+
prev_edge_index = prev_edge_line = nothing
132133
if previous !== nothing && isa(previous, Tuple)
133134
prev_edge_index, prev_edge_line = previous[2], previous[3]
134135
end

src/transform/reconstruct.jl

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,47 +7,35 @@ function expand_residuals(state::TransformationState, key::TornCacheKey, compres
77
# (slot_assignments, var_assignment, _) = assign_slots(state, key, var_eq_matching)
88
for (eq, incidence) in enumerate(result.total_incidence)
99
if !is_var_part_known_linear(incidence)
10-
sign = get_linear_residual_sign(result, eq)
11-
push!(expanded, sign * compressed[i])
10+
push!(expanded, compressed[i])
1211
i += 1
1312
continue
1413
end
1514

1615
residual = 0.0
17-
sign = get_linear_residual_sign(result, eq)
1816
for (coeff, var) in zip(nonzeros(incidence.row), rowvals(incidence.row))
1917
var -= 1
2018
if var == 0
2119
value = t
2220
else
23-
vint = invview(structure.var_to_diff)[var]
24-
(slot, source) = vint === nothing ? (var, u) : (vint, du)
25-
vint !== nothing && (i += 1)
21+
is_diff = is_differential_variable(structure, var)
22+
source = ifelse(is_diff, du, u)
23+
slot = count(v -> is_differential_variable(structure, v) == is_diff, 1:var)
24+
eq == var && !is_diff && (i += 1)
2625
value = source[slot]
2726
end
2827
residual += value * coeff
2928
end
3029
constant_term = isa(incidence.typ, Const) ? incidence.typ.val::Float64 : 0.0
31-
push!(expanded, constant_term + sign * residual)
30+
push!(expanded, constant_term + residual)
3231
end
3332
return expanded
3433
end
3534

36-
function get_linear_residual_sign(result::DAEIPOResult, eq::Int)
37-
# If a linear solved term appears with a positive coefficient,
38-
# the residual will be taken as the negative of the value provided to `always!`.
39-
# For example: ẋ₁ - x₁x₂ = 0
40-
# -ẋ₁ = -x₁x₂
41-
# ẋ₁ = -x₁x₂/-1
42-
# ẋ₁ = x₁x₂
43-
# 0 = x₁x₂ - ẋ₁ <-- residual
44-
incidence = result.total_incidence[eq]
45-
for (coeff, var) in zip(nonzeros(incidence.row), rowvals(incidence.row))
46-
var - 1 === eq || continue
47-
isa(coeff, Float64) || continue
48-
return coeff 0 ? -1 : 1
49-
end
50-
return 1
35+
function is_differential_variable(structure::DAESystemStructure, var)
36+
structure.var_to_diff[var] !== nothing && return false
37+
return invview(structure.var_to_diff)[var] !== nothing && return true
38+
@assert false
5139
end
5240

5341
function expand_residuals(f, residuals, u, du, t)
@@ -71,6 +59,7 @@ function extract_removed_states(state::TransformationState, key::TornCacheKey, t
7159
vint === nothing || key.diff_states === nothing || !in(vint, key.diff_states) || continue
7260
push!(removed_states, var)
7361
end
62+
# @sshow removed_states
7463
return removed_states
7564
end
7665

src/transform/unoptimized.jl

Lines changed: 73 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,67 +6,90 @@ function rhs_finish_noopt!(
66
key::UnoptimizedKey,
77
world::UInt,
88
settings::Settings,
9-
indexT=Int)
9+
equation_to_residual_mapping = 1:length(state.structure.eq_to_diff),
10+
variable_to_state_mapping = map_variables_to_states(state))
1011

1112
(; result, structure) = state
12-
result_ci = find_matching_ci(ci -> ci.inferred === key, ci.def, world)
13+
result_ci = find_matching_ci(ci -> ci.owner === key, ci.def, world)
1314
if result_ci !== nothing
1415
return result_ci
1516
end
1617

1718
ir = copy(result.ir)
18-
slotnames = [:captures, :vars, :out, :du, :u, :out_indices, :du_indices, :u_indices, :t]
19-
argtypes = [Tuple, Vector{Float64}, Vector{Float64}, Vector{Float64}, Vector{Float64}, VectorIntViewType, VectorIntViewType, VectorIntViewType, Float64]
19+
# TODO: use original function arguments too
20+
slotnames = [:captures, :out, :du, :u, :residuals, :states, :t]
21+
argtypes = [Tuple, Vector{Float64}, Vector{Float64}, Vector{Float64}, Vector{Int}, Vector{Int}, Float64]
2022
append!(empty!(ir.argtypes), argtypes)
21-
captures, vars, out, du, u, out_indices, du_indices, u_indices, t = Argument.(eachindex(slotnames))
22-
# TODO: use `out_indices`, `du_indices`, `u_indices`
23+
captures, out, du, u, residuals, states, t = Argument.(eachindex(slotnames))
2324
@assert length(slotnames) == length(ir.argtypes)
2425

2526
equations = Pair{SSAValue, Eq}[]
27+
callee_to_caller_eq_map = invert_eq_callee_mapping(result.eq_callee_mapping)
2628
compact = IncrementalCompact(ir)
2729

28-
for ((_, i), _) in compact
30+
# @sshow equation_to_residual_mapping variable_to_state_mapping
31+
32+
for ((old, i), _) in compact
2933
ssaidx = SSAValue(i)
3034
inst = compact[ssaidx]
3135
stmt = inst[:stmt]
3236
type = inst[:type]
3337
line = inst[:line]
3438

35-
if is_known_invoke(stmt, Intrinsics.equation, compact)
39+
if i == 1
40+
# @insert_instruction_here(compact, nothing, settings, println("Residuals: ", residuals)::Any)
41+
# @insert_instruction_here(compact, nothing, settings, println("States: ", states)::Any)
42+
replace_uses!(compact, (old, inst) => SSAValue(3))
43+
end
44+
45+
if is_known_invoke_or_call(stmt, Intrinsics.variable, compact)
46+
var = idnum(type)
47+
index = @insert_instruction_here(compact, line, settings, getindex(states, var)::Int)
48+
value = @insert_instruction_here(compact, line, settings, getindex(u, index)::Float64)
49+
replace_uses!(compact, (old, inst) => value)
50+
# @insert_instruction_here(compact, line, settings, println("Variable (", var, "): ", value)::Float64)
51+
elseif is_known_invoke(stmt, Intrinsics.ddt, compact)
52+
var = idnum(type)
53+
index = @insert_instruction_here(compact, line, settings, getindex(states, var)::Int)
54+
value = @insert_instruction_here(compact, line, settings, getindex(du, index)::Float64)
55+
replace_uses!(compact, (old, inst) => value)
56+
# @insert_instruction_here(compact, line, settings, println("Variable derivative (", var, " := ", invview(structure.var_to_diff)[var], "′): ", value)::Any)
57+
elseif is_known_invoke(stmt, Intrinsics.equation, compact)
3658
push!(equations, ssaidx => type::Eq)
3759
inst[:stmt] = nothing
38-
elseif is_known_invoke(stmt, Intrinsics.ddt, compact)
39-
var = invview(structure.var_to_diff)[idnum(type)]
40-
getdu = Expr(:call, getindex, du, var)
41-
replace_call!(compact, ssaidx, getdu, settings, @__SOURCE__)
42-
inst[:type] = Float64
4360
elseif is_equation_call(stmt, compact)
4461
callee, value = stmt.args[2], stmt.args[3]
4562
i = findfirst(x -> first(x) == callee, equations)::Int
4663
eq = last(equations[i])
47-
call = Expr(:call, setindex!, out, value, eq.id)
48-
replace_call!(compact, ssaidx, call, settings, @__SOURCE__)
49-
elseif is_known_invoke_or_call(stmt, Intrinsics.variable, compact)
50-
var = idnum(type)
51-
call = Expr(:call, getindex, u, var)
52-
replace_call!(compact, ssaidx, call, settings, @__SOURCE__)
53-
inst[:type] = Float64
64+
index = @insert_instruction_here(compact, line, settings, getindex(residuals, eq.id)::Int)
65+
ret = @insert_instruction_here(compact, line, settings, setindex!(out, value, index)::Any)
66+
replace_uses!(compact, (old, inst) => ret)
67+
# @insert_instruction_here(compact, line, settings, println("Residuals (index = ", index, ", value = ", value, "): ", residuals)::Any)
5468
elseif is_known_invoke_or_call(stmt, Intrinsics.sim_time, compact)
5569
inst[:stmt] = t
5670
elseif is_known_invoke_or_call(stmt, Intrinsics.epsilon, compact)
5771
inst[:stmt] = 0.0
5872
elseif isexpr(stmt, :invoke)
59-
@sshow stmt
73+
info = inst[:info]::MappingInfo
6074
callee_ci, callee_f = stmt.args[1]::CodeInstance, stmt.args[2]
6175
callee_result = structural_analysis!(callee_ci, world, settings)
6276
callee_structure = make_structure_from_ipo(callee_result)
6377
callee_state = TransformationState(callee_result, callee_structure)
64-
callee_daef_ci = rhs_finish_noopt!(callee_state, callee_ci, UnoptimizedKey(), world, settings)
65-
callee_captures = ()
66-
# TODO: compute indices into `u`/`du`/`out`
67-
empty!(stmt.args)
68-
push!(stmt.args, callee_daef_ci, callee_captures, vars,
69-
out, du, u, out_indices, du_indices, u_indices, t)
78+
79+
callee_residuals = equation_to_residual_mapping[callee_to_caller_eq_map[StructuralSSARef(old)]]
80+
caller_variables = idnum.(info.mapping.var_coeffs)
81+
callee_states = variable_to_state_mapping[caller_variables]
82+
83+
callee_daef_ci = rhs_finish_noopt!(callee_state, callee_ci, UnoptimizedKey(), world, settings, callee_residuals, callee_states)
84+
call = @insert_instruction_here(compact, line, settings, (:invoke)(callee_daef_ci, callee_f,
85+
out,
86+
du,
87+
u,
88+
@insert_instruction_here(compact, line, settings, getindex(Int, callee_residuals...)::Vector{Int}),
89+
@insert_instruction_here(compact, line, settings, getindex(Int, callee_states...)::Vector{Int}),
90+
t)::type)
91+
# TODO: add `stmt.args[3:end]`
92+
replace_uses!(compact, (old, inst) => call)
7093
end
7194
type = inst[:type]
7295
if isa(type, Incidence) || isa(type, Eq)
@@ -75,9 +98,28 @@ function rhs_finish_noopt!(
7598
end
7699

77100
daef_ci = rhs_finish_ir!(Compiler.finish(compact), ci, settings, key, slotnames)
101+
# @sshow daef_ci.inferred
78102
return daef_ci
79103
end
80104

105+
function map_variables_to_states(state::TransformationState)
106+
(; structure) = state
107+
diff_to_var = invview(structure.var_to_diff)
108+
states = Int[]
109+
prev_state = 0
110+
for var in continuous_variables(state)
111+
ref = is_differential_variable(structure, var) ? diff_to_var[var] : var
112+
state = @something(get(states, ref, nothing), prev_state += 1)
113+
push!(states, state)
114+
end
115+
return states
116+
end
117+
118+
function replace_uses!(compact, ((old, inst), new))
119+
inst[:stmt] = nothing
120+
compact.ssa_rename[old] = new
121+
end
122+
81123
function sciml_to_internal_abi_noopt!(ir::IRCode, state::TransformationState, internal_ci::CodeInstance, settings::Settings)
82124
slotnames = [:captures, :out, :du, :u, :p, :t]
83125
captures, out, du, u, p, t = Argument.(eachindex(slotnames))
@@ -90,14 +132,11 @@ function sciml_to_internal_abi_noopt!(ir::IRCode, state::TransformationState, in
90132
line = ir[SSAValue(1)][:line]
91133

92134
internal_oc = @insert_instruction_here(compact, line, settings, getfield(captures, 1)::Core.OpaqueClosure)
93-
# TODO: Compute proper indices.
94135
neqs = length(state.structure.eq_to_diff)
95-
out_indices = @insert_instruction_here(compact, line, settings, view(out, 1:neqs)::VectorIntViewType)
96-
du_indices = @insert_instruction_here(compact, line, settings, view(du, 1:neqs)::VectorIntViewType)
97-
u_indices = @insert_instruction_here(compact, line, settings, view(u, 1:neqs)::VectorIntViewType)
98-
# TODO: Provide actual external variables.
99-
vars = @insert_instruction_here(compact, line, settings, getindex(Float64)::Vector{Float64})
100-
@insert_instruction_here(compact, line, settings, (:invoke)(internal_ci, internal_oc, vars, out, du, u, out_indices, du_indices, u_indices, t)::Nothing)
136+
nvars = length(state.structure.var_to_diff)
137+
residuals = @insert_instruction_here(compact, line, settings, getindex(Int, 1:neqs...)::Vector{Int})
138+
states = @insert_instruction_here(compact, line, settings, getindex(Int, map_variables_to_states(state)...)::Vector{Int})
139+
@insert_instruction_here(compact, line, settings, (:invoke)(internal_ci, internal_oc, out, du, u, residuals, states, t)::Nothing)
101140
@insert_instruction_here(compact, line, settings, (return nothing)::Union{})
102141

103142
ir = Compiler.finish(compact)

src/utils.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,15 +162,19 @@ end
162162
@sshow stmt
163163
@sshow length(ir.stmts) typeof(val)
164164
165-
Drop-in replacement for `@show`, but using `jl_safe_printf` to avoid task switches.
165+
Drop-in replacement for `@show`, but using `Core.println` to avoid task switches.
166166
167167
This directly prints to C stdout; `stdout` redirects won't have any effect.
168168
"""
169169
macro sshow(exs...)
170170
blk = Expr(:block)
171171
for ex in exs
172-
push!(blk.args, :(Core.println($(sprint(Base.show_unquoted,ex)*" = "),
173-
repr(begin local value = $(esc(ex)) end))))
172+
push!(blk.args, quote
173+
value = $(esc(ex))
174+
Core.print($(sprint(Base.show_unquoted, ex)))
175+
Core.print(" = ")
176+
Core.println(sprint(print, value, context = :color => true))
177+
end)
174178
end
175179
isempty(exs) || push!(blk.args, :value)
176180
return blk

test/validation.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,15 @@ end
4949

5050
# IPO
5151

52+
u = [2.0]
53+
du = [3.0]
5254
residuals, expanded_residuals = compute_residual_vectors(() -> onecall!(), u, du; t = 1.0)
5355
@test residuals [1.0]
5456
@test residuals expanded_residuals
5557

5658
u = [2.0, 4.0]
5759
du = [3.0, 7.0]
58-
# ERROR: BoundsError: attempt to access 2-element Vector{Float64} at index [3]
59-
# (for `var = 3`)
60-
refresh(); residuals, expanded_residuals = compute_residual_vectors(twocall!, u, du; t = 1.0)
61-
@test residuals [1.0]
60+
residuals, expanded_residuals = compute_residual_vectors(twocall!, u, du; t = 1.0)
61+
@test residuals [1.0, 3.0]
6262
@test residuals expanded_residuals
6363
end;

0 commit comments

Comments
 (0)