Skip to content

Commit a81204f

Browse files
committed
Skip flattening in unoptimized context, WIP on IPO support
1 parent 24ef40d commit a81204f

File tree

4 files changed

+62
-20
lines changed

4 files changed

+62
-20
lines changed

src/analysis/structural.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -357,18 +357,19 @@ function _structural_analysis!(ci::CodeInstance, world::UInt, settings::Settings
357357
opaque_eligible = false
358358
end
359359

360-
# Rewrite to flattened ABI
361-
compact[SSAValue(i)] = nothing
362-
compact.result_idx -= 1
363-
new_args = _flatten_parameter!(Compiler.optimizer_lattice(refiner), compact, callee_codeinst.inferred.ir.argtypes, arg->stmt.args[arg+1], line, settings)
364-
365-
new_call = insert_instruction_here!(compact, settings, @__SOURCE__,
366-
NewInstruction(Expr(:invoke, (StructuralSSARef(compact.result_idx), callee_codeinst), new_args...), stmtype, info, line, stmtflags))
367-
compact.ssa_rename[compact.idx - 1] = new_call
360+
if !settings.skip_optimizations
361+
# Rewrite to flattened ABI
362+
compact[SSAValue(i)] = nothing
363+
compact.result_idx -= 1
364+
new_args = _flatten_parameter!(Compiler.optimizer_lattice(refiner), compact, callee_codeinst.inferred.ir.argtypes, arg->stmt.args[arg+1], line, settings)
365+
new_call = insert_instruction_here!(compact, settings, @__SOURCE__,
366+
NewInstruction(Expr(:invoke, (StructuralSSARef(compact.result_idx), callee_codeinst), new_args...), stmtype, info, line, stmtflags))
367+
compact.ssa_rename[compact.idx - 1] = new_call
368+
end
368369

369370
cms = CallerMappingState(result, refiner.var_to_diff, refiner.varclassification, refiner.varkinds, eqclassification, eqkinds)
370-
err = add_internal_equations_to_structure!(refiner, cms, total_incidence, eq_callee_mapping, StructuralSSARef(new_call.id),
371-
result, mapping)
371+
err = add_internal_equations_to_structure!(refiner, cms, total_incidence, eq_callee_mapping,
372+
StructuralSSARef(i), result, mapping)
372373
if err !== true
373374
return UncompilableIPOResult(warnings, UnsupportedIRException(err, ir))
374375
end

src/transform/reconstruct.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ end
5353
function expand_residuals(f, residuals, u, du, t)
5454
result = @code_structure result=true f()
5555
structure = make_structure_from_ipo(result)
56-
state = TransformationState(result, structure, copy(result.total_incidence))
56+
state = TransformationState(result, structure)
5757
key, _ = top_level_state_selection!(state)
5858
return expand_residuals(state, key, residuals, u, du, t)
5959
end

src/transform/unoptimized.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ function rhs_finish_noopt!(
1818
slotnames = [:captures, :vars, :out, :du, :u, :out_indices, :du_indices, :u_indices, :t]
1919
argtypes = [Tuple, Vector{Float64}, Vector{Float64}, Vector{Float64}, Vector{Float64}, VectorIntViewType, VectorIntViewType, VectorIntViewType, Float64]
2020
append!(empty!(ir.argtypes), argtypes)
21-
captures, vars, out, du, u, out_indices, du_indics, u_indices, t = Argument.(eachindex(slotnames))
21+
captures, vars, out, du, u, out_indices, du_indices, u_indices, t = Argument.(eachindex(slotnames))
22+
# TODO: use `out_indices`, `du_indices`, `u_indices`
2223
@assert length(slotnames) == length(ir.argtypes)
2324

2425
equations = Pair{SSAValue, Eq}[]
@@ -45,17 +46,27 @@ function rhs_finish_noopt!(
4546
eq = last(equations[i])
4647
call = Expr(:call, setindex!, out, value, eq.id)
4748
replace_call!(compact, ssaidx, call, settings, @__SOURCE__)
48-
elseif is_known_invoke_or_call(stmt, variable, compact)
49+
elseif is_known_invoke_or_call(stmt, Intrinsics.variable, compact)
4950
var = idnum(type)
5051
call = Expr(:call, getindex, u, var)
5152
replace_call!(compact, ssaidx, call, settings, @__SOURCE__)
5253
inst[:type] = Float64
53-
elseif is_known_invoke_or_call(stmt, sim_time, compact)
54+
elseif is_known_invoke_or_call(stmt, Intrinsics.sim_time, compact)
5455
inst[:stmt] = t
55-
# TODO: process flattened variables
56-
# TODO: process other intrinsics (epsilon, etc)
57-
# else
58-
# replace_if_intrinsic!(compact, settings, ssaidx, nothing, nothing, nothing, t, nothing)
56+
elseif is_known_invoke_or_call(stmt, Intrinsics.epsilon, compact)
57+
inst[:stmt] = 0.0
58+
elseif isexpr(stmt, :invoke)
59+
@sshow stmt
60+
callee_ci, callee_f = stmt.args[1]::CodeInstance, stmt.args[2]
61+
callee_result = structural_analysis!(callee_ci, world, settings)
62+
callee_structure = make_structure_from_ipo(callee_result)
63+
callee_state = TransformationState(callee_result, callee_structure)
64+
callee_daef_ci = rhs_finish_noopt!(callee_state, callee_ci, UnoptimizedKey(), world, settings)
65+
callee_captures = ()
66+
# TODO: compute indices into `u`/`du`/`out`
67+
empty!(stmt.args)
68+
push!(stmt.args, callee_daef_ci, callee_captures, vars,
69+
out, du, u, out_indices, du_indices, u_indices, t)
5970
end
6071
type = inst[:type]
6172
if isa(type, Incidence) || isa(type, Eq)

test/validation.jl

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ const *ᵢ = Core.Intrinsics.mul_float
1111
const += Core.Intrinsics.add_float
1212
const -= Core.Intrinsics.sub_float
1313

14-
@noinline function f()
14+
function f()
1515
x₁ = continuous() # selected
1616
x₂ = continuous() # selected
1717
x₃ = continuous() # algebraic, optimized away
@@ -22,12 +22,42 @@ const -ᵢ = Core.Intrinsics.sub_float
2222
always!(x₄ *ᵢ x₄ -ddt(x₁))
2323
end
2424

25+
@noinline function onecall!()
26+
x = continuous()
27+
always!(ddt(x) - x)
28+
end
29+
30+
function twocall!()
31+
onecall!(); onecall!();
32+
return nothing
33+
end
34+
2535
@testset "Validation" begin
2636
refresh() # TODO: remove before merge
27-
f()
37+
38+
u = [2.0]
39+
du = [3.0]
40+
residuals, expanded_residuals = compute_residual_vectors(onecall!, u, du; t = 1.0)
41+
@test residuals [1.0]
42+
@test residuals expanded_residuals
43+
2844
u = [3.0, 1.0, 100.0, 4.0]
2945
du = [3.0, 0.0, 0.0, 0.0]
3046
residuals, expanded_residuals = compute_residual_vectors(f, u, du; t = 1.0)
3147
@test residuals [0.0, -3.0, 97.0, 13.0]
3248
@test residuals expanded_residuals
49+
50+
# IPO
51+
52+
residuals, expanded_residuals = compute_residual_vectors(() -> onecall!(), u, du; t = 1.0)
53+
@test residuals [1.0]
54+
@test residuals expanded_residuals
55+
56+
u = [2.0, 4.0]
57+
du = [3.0, 7.0]
58+
# ERROR: BoundsError: attempt to access 2-element Vector{Float64} at index [3]
59+
# (for `var = 3`)
60+
refresh(); residuals, expanded_residuals = compute_residual_vectors(twocall!, u, du; t = 1.0)
61+
@test residuals [1.0]
62+
@test residuals expanded_residuals
3363
end;

0 commit comments

Comments
 (0)