Skip to content

Commit 8e2cd1f

Browse files
committed
[WIP] Add unoptimized compilation path
1 parent 3829ef9 commit 8e2cd1f

File tree

12 files changed

+347
-125
lines changed

12 files changed

+347
-125
lines changed

src/DAECompiler.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ module DAECompiler
2828
include("transform/state_selection.jl")
2929
include("transform/common.jl")
3030
include("transform/runtime.jl")
31+
include("transform/unoptimized.jl")
3132
include("transform/tearing/schedule.jl")
3233
include("transform/codegen/dae_factory.jl")
3334
include("transform/codegen/ode_factory.jl")

src/interface.jl

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,20 +41,23 @@ function factory_gen(@nospecialize(fT), settings::Settings, world::UInt = Base.g
4141
end
4242

4343
# Select differential and algebraic states
44-
ret = top_level_state_selection!(tstate)
45-
46-
if isa(ret, UncompilableIPOResult)
47-
return Base.generated_body_to_codeinfo(
48-
Expr(:lambda, Any[:var"#self", :settings, :f], Expr(:block, Expr(:return, Expr(:call, throw, ret.error)))),
49-
@__MODULE__, false)
50-
end
51-
(diff_key, init_key) = ret
52-
53-
if settings.mode in (DAE, DAENoInit, ODE, ODENoInit)
54-
tearing_schedule!(tstate, ci, diff_key, world, settings)
55-
end
56-
if settings.mode in (InitUncompress, DAE, ODE)
57-
tearing_schedule!(tstate, ci, init_key, world, settings)
44+
if settings.skip_optimizations
45+
diff_key = torn_cache_key(tstate, settings)
46+
init_key = nothing
47+
else
48+
ret = top_level_state_selection!(tstate)
49+
if isa(ret, UncompilableIPOResult)
50+
return Base.generated_body_to_codeinfo(
51+
Expr(:lambda, Any[:var"#self", :settings, :f], Expr(:block, Expr(:return, Expr(:call, throw, ret.error)))),
52+
@__MODULE__, false)
53+
end
54+
(diff_key, init_key) = ret
55+
if settings.mode in (DAE, DAENoInit, ODE, ODENoInit)
56+
tearing_schedule!(tstate, ci, diff_key, world, settings)
57+
end
58+
if settings.mode in (InitUncompress, DAE, ODE)
59+
tearing_schedule!(tstate, ci, init_key, world, settings)
60+
end
5861
end
5962

6063
# Generate the IR implementation of `factory`, returning the DAEFunction/ODEFunction

src/problem_interface.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@ function DAECProblem(f, init::Union{Vector, Tuple{Vararg{Pair}}}, tspan::Tuple{R
2525
force_inline_all=false,
2626
insert_stmt_debuginfo=false,
2727
insert_ssa_debuginfo=false,
28+
skip_optimizations=false,
2829
kwargs...)
29-
settings = Settings(; force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo)
30+
settings = Settings(; force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo, skip_optimizations)
3031
DAECProblem(f, init, guesses, tspan, kwargs, settings, missing, nothing, nothing)
3132
end
3233

@@ -35,13 +36,14 @@ function DAECProblem(f, tspan::Tuple{Real, Real} = (0., 1.);
3536
force_inline_all=false,
3637
insert_stmt_debuginfo=false,
3738
insert_ssa_debuginfo=false,
39+
skip_optimizations=false,
3840
kwargs...)
39-
settings = Settings(; force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo)
41+
settings = Settings(; force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo, skip_optimizations)
4042
DAECProblem(f, nothing, guesses, tspan, kwargs, settings, missing, nothing, nothing)
4143
end
4244

4345
function DiffEqBase.get_concrete_problem(prob::DAECProblem, isadaptive; kwargs...)
44-
settings = Settings(; mode=prob.init === nothing ? DAE : DAENoInit, prob.settings.force_inline_all, prob.settings.insert_stmt_debuginfo, prob.settings.insert_ssa_debuginfo)
46+
settings = Settings(; mode=prob.init === nothing ? DAE : DAENoInit, prob.settings.force_inline_all, prob.settings.insert_stmt_debuginfo, prob.settings.insert_ssa_debuginfo, prob.settings.skip_optimizations)
4547
(daef, differential_vars) = factory(Val(settings), prob.f)
4648

4749
u0 = zeros(length(differential_vars))
@@ -77,8 +79,9 @@ function ODECProblem(f, init::Union{Vector, Tuple{Vararg{Pair}}}, tspan::Tuple{R
7779
force_inline_all=false,
7880
insert_stmt_debuginfo=false,
7981
insert_ssa_debuginfo=false,
82+
skip_optimizations=false,
8083
kwargs...)
81-
settings = Settings(; force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo)
84+
settings = Settings(; force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo, skip_optimizations)
8285
ODECProblem(f, init, guesses, tspan, kwargs, settings, missing, nothing)
8386
end
8487

@@ -87,13 +90,14 @@ function ODECProblem(f, tspan::Tuple{Real, Real} = (0., 1.);
8790
force_inline_all=false,
8891
insert_stmt_debuginfo=false,
8992
insert_ssa_debuginfo=false,
93+
skip_optimizations=false,
9094
kwargs...)
91-
settings = Settings(; force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo)
95+
settings = Settings(; force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo, skip_optimizations)
9296
ODECProblem(f, nothing, guesses, tspan, kwargs, settings, missing, nothing)
9397
end
9498

9599
function DiffEqBase.get_concrete_problem(prob::ODECProblem, isadaptive; kwargs...)
96-
settings = Settings(; mode=prob.init === nothing ? ODE : ODENoInit, prob.settings.force_inline_all, prob.settings.insert_stmt_debuginfo, prob.settings.insert_ssa_debuginfo)
100+
settings = Settings(; mode=prob.init === nothing ? ODE : ODENoInit, prob.settings.force_inline_all, prob.settings.insert_stmt_debuginfo, prob.settings.insert_ssa_debuginfo, prob.settings.skip_optimizations)
97101
(odef, n) = factory(Val(settings), prob.f)
98102

99103
u0 = zeros(n)

src/reflection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ code_ad_by_type(@nospecialize(tt::Type); kwargs...) =
2828

2929
function code_structure_by_type(@nospecialize(tt::Type); world::UInt = Base.tls_world_age(), result = false, matched = false, mode = DAE, force_inline_all = false, insert_stmt_debuginfo = false, insert_ssa_debuginfo = false, kwargs...)
3030
ci = _code_ad_by_type(tt; world, kwargs...)
31-
settings = Settings(; mode, force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo)
31+
settings = Settings(; mode, force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo, skip_optimizations)
3232
_result = structural_analysis!(ci, world, settings)
3333
isa(_result, UncompilableIPOResult) && throw(_result.error)
3434
!matched && return result ? _result : _result.ir

src/settings.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ struct Settings
1414
force_inline_all::Bool
1515
insert_stmt_debuginfo::Bool
1616
insert_ssa_debuginfo::Bool
17+
skip_optimizations::Bool
1718
end
18-
Settings(; mode::GenerationMode=DAE, force_inline_all::Bool=false, insert_stmt_debuginfo::Bool=false, insert_ssa_debuginfo::Bool=false) = Settings(mode, force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo)
19+
Settings(; mode::GenerationMode=DAE, force_inline_all::Bool=false, insert_stmt_debuginfo::Bool=false, insert_ssa_debuginfo::Bool=false, skip_optimizations::Bool = false) = Settings(mode, force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo, skip_optimizations)

src/transform/codegen/dae_factory.jl

Lines changed: 93 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -39,101 +39,51 @@ function make_daefunction(f, initf)
3939
DAEFunction(f; initialization_data = SciMLBase.OverrideInitData(NonlinearProblem((args...)->nothing, nothing, nothing), nothing, initf, nothing, nothing, Val{false}()))
4040
end
4141

42-
"""
43-
dae_factory_gen(ci, key)
44-
45-
Generate the `factory` function for CodeInstance `ci`, returning a DAEFunction.
46-
The resulting function is roughly:
47-
48-
```
49-
function factory(settings, f)
50-
# Run all parts of `f` that do not depend on state
51-
state_invariant_pieces = f_state_invariant()
52-
f! = %new_opaque_closure(f_rhs, state_invariant_pieces)
53-
DAEFunction(f!), differential_vars
42+
function continuous_variables(state::TransformationState)
43+
filter(var -> varkind(state, var) == Intrinsics.Continuous, 1:length(state.result.var_to_diff))
5444
end
55-
```
56-
57-
"""
58-
function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::TornCacheKey, world::UInt, settings::Settings, init_key::Union{TornCacheKey, Nothing})
59-
result = state.result
60-
torn_ci = find_matching_ci(ci->isa(ci.owner, TornIRSpec) && ci.owner.key == key, ci.def, world)
61-
torn_ir = torn_ci.inferred
62-
63-
(;ir_sicm) = torn_ir
6445

65-
ir_factory = copy(ci.inferred.ir)
66-
pushfirst!(ir_factory.argtypes, Settings)
67-
pushfirst!(ir_factory.argtypes, typeof(factory))
68-
compact = IncrementalCompact(ir_factory)
69-
70-
local line
71-
if ir_sicm !== nothing
72-
sicm_ci = find_matching_ci(ci->isa(ci.owner, SICMSpec) && ci.owner.key == key, ci.def, world)
73-
@assert sicm_ci !== nothing
74-
75-
line = result.ir[SSAValue(1)][:line]
76-
param_list = flatten_parameter!(Compiler.fallback_lattice, compact, ci.inferred.ir.argtypes[1:end], argn->Argument(2+argn), line, settings)
77-
sicm = @insert_instruction_here compact line settings invoke(param_list, sicm_ci)::Tuple
78-
else
79-
sicm = ()
80-
end
81-
82-
argt = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64}
83-
84-
daef_ci = rhs_finish!(state, ci, key, world, settings, 1)
46+
const SCIML_ABI = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64}
8547

86-
# Create a small opaque closure to adapt from SciML ABI to our own internal
87-
# ABI
48+
function sciml_to_internal_abi!(ir::IRCode, state::TransformationState, internal_ci::CodeInstance, key::TornCacheKey, var_eq_matching, settings::Settings)
49+
(; result, structure) = state
8850

8951
numstates = zeros(Int, Int(LastEquationStateKind))
90-
91-
all_states = Int[]
92-
for var = 1:length(result.var_to_diff)
93-
varkind(state, var) == Intrinsics.Continuous || continue
52+
for var in continuous_variables(state)
9453
kind = classify_var(result.var_to_diff, key, var)
9554
kind == nothing && continue
9655
numstates[kind] += 1
97-
(kind != AlgebraicDerivative) && push!(all_states, var)
9856
end
9957

100-
ir_oc = copy(ci.inferred.ir)
101-
empty!(ir_oc.argtypes)
102-
push!(ir_oc.argtypes, Tuple)
103-
push!(ir_oc.argtypes, Vector{Float64})
104-
push!(ir_oc.argtypes, Vector{Float64})
105-
push!(ir_oc.argtypes, Vector{Float64})
106-
push!(ir_oc.argtypes, SciMLBase.NullParameters)
107-
push!(ir_oc.argtypes, Float64)
58+
empty!(ir.argtypes)
59+
push!(ir.argtypes, Tuple) # opaque closure captures
60+
append!(ir.argtypes, fieldtypes(SCIML_ABI))
10861

109-
oc_compact = IncrementalCompact(ir_oc)
62+
compact = IncrementalCompact(ir)
11063

11164
# Zero the output
112-
line = ir_oc[SSAValue(1)][:line]
113-
@insert_instruction_here oc_compact line settings zero!(Argument(2))::VectorViewType
65+
line = ir[SSAValue(1)][:line]
66+
@insert_instruction_here compact line settings zero!(Argument(2))::VectorViewType
11467

11568
# out_du_mm, out_eq, in_u_mm, in_u_unassgn, in_du_unassgn, in_alg
11669
nassgn = numstates[AssignedDiff]
11770
ntotalstates = numstates[AssignedDiff] + numstates[UnassignedDiff] + numstates[Algebraic]
118-
out_du_mm = @insert_instruction_here oc_compact line settings view(Argument(2), 1:nassgn)::VectorViewType
119-
out_eq = @insert_instruction_here oc_compact line settings view(Argument(2), (nassgn+1):ntotalstates)::VectorViewType
71+
out_du_mm = @insert_instruction_here compact line settings view(Argument(2), 1:nassgn)::VectorViewType
72+
out_eq = @insert_instruction_here compact line settings view(Argument(2), (nassgn+1):ntotalstates)::VectorViewType
12073

121-
(in_du_assgn, in_du_unassgn) = sciml_dae_split_du!(oc_compact, line, settings, Argument(3), numstates)
122-
(in_u_mm, in_u_unassgn, in_alg) = sciml_dae_split_u!(oc_compact, line, settings, Argument(4), numstates)
74+
(in_du_assgn, in_du_unassgn) = sciml_dae_split_du!(compact, line, settings, Argument(3), numstates)
75+
(in_u_mm, in_u_unassgn, in_alg) = sciml_dae_split_u!(compact, line, settings, Argument(4), numstates)
12376

12477
# Call DAECompiler-generated RHS with internal ABI
125-
oc_sicm = @insert_instruction_here oc_compact line settings getfield(Argument(1), 1)::Core.OpaqueClosure
78+
oc_sicm = @insert_instruction_here compact line settings getfield(Argument(1), 1)::Core.OpaqueClosure
12679

12780
# N.B: The ordering of arguments should match the ordering in the StateKind enum
128-
@insert_instruction_here oc_compact line settings (:invoke)(daef_ci, oc_sicm, (), in_u_mm, in_u_unassgn, in_du_unassgn, in_alg, out_du_mm, out_eq, Argument(6))::Nothing
129-
130-
# TODO: We should not have to recompute this here
131-
var_eq_matching = matching_for_key(state, key)
132-
(slot_assignments, var_assignment, eq_assignment) = assign_slots(state, key, var_eq_matching)
81+
@insert_instruction_here compact line settings (:invoke)(internal_ci, oc_sicm, (), in_u_mm, in_u_unassgn, in_du_unassgn, in_alg, out_du_mm, out_eq, Argument(6))::Nothing
13382

13483
# Manually apply mass matrix and implicit equations between selected states
135-
for v = 1:ndsts(state.structure.graph)
136-
vdiff = state.structure.var_to_diff[v]
84+
(_, var_assignment, _) = assign_slots(state, key, var_eq_matching)
85+
for v = 1:ndsts(structure.graph)
86+
vdiff = structure.var_to_diff[v]
13787
vdiff === nothing && continue
13888

13989
if var_eq_matching[v] !== SelectedState() || var_eq_matching[vdiff] !== SelectedState()
@@ -146,22 +96,81 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
14696
@assert kind == AssignedDiff
14797
@assert dkind in (AssignedDiff, UnassignedDiff)
14898

149-
v_val = @insert_instruction_here oc_compact line settings getindex(dkind == AssignedDiff ? in_u_mm : in_u_unassgn, dslot)::Any
150-
@insert_instruction_here oc_compact line settings setindex!(out_du_mm, v_val, slot)::Any
99+
v_val = @insert_instruction_here compact line settings getindex(dkind == AssignedDiff ? in_u_mm : in_u_unassgn, dslot)::Any
100+
@insert_instruction_here compact line settings setindex!(out_du_mm, v_val, slot)::Any
151101
end
152102

153-
bc = @insert_instruction_here oc_compact line settings Base.Broadcast.broadcasted(-, out_du_mm, in_du_assgn)::Any
154-
@insert_instruction_here oc_compact line settings Base.Broadcast.materialize!(out_du_mm, bc)::Nothing
103+
bc = @insert_instruction_here compact line settings Base.Broadcast.broadcasted(-, out_du_mm, in_du_assgn)::Any
104+
@insert_instruction_here compact line settings Base.Broadcast.materialize!(out_du_mm, bc)::Nothing
155105

156106
# Return
157-
@insert_instruction_here oc_compact line settings (return nothing)::Union{}
107+
@insert_instruction_here compact line settings (return nothing)::Union{}
158108

159-
ir_oc = Compiler.finish(oc_compact)
160-
maybe_rewrite_debuginfo!(ir_oc, settings)
161-
resize!(ir_oc.cfg.blocks, 1)
162-
empty!(ir_oc.cfg.blocks[1].succs)
163-
Compiler.verify_ir(ir_oc)
164-
oc = Core.OpaqueClosure(ir_oc)
109+
ir = Compiler.finish(compact)
110+
maybe_rewrite_debuginfo!(ir, settings)
111+
resize!(ir.cfg.blocks, 1)
112+
empty!(ir.cfg.blocks[1].succs)
113+
Compiler.verify_ir(ir)
114+
115+
@async @eval Main begin
116+
interface_ir = $ir
117+
end
118+
119+
return Core.OpaqueClosure(ir; slotnames = [:captures, :out, :du, :u, :p, :t])
120+
end
121+
122+
"""
123+
dae_factory_gen(ci, key)
124+
125+
Generate the `factory` function for CodeInstance `ci`, returning a DAEFunction.
126+
The resulting function is roughly:
127+
128+
```
129+
function factory(settings, f)
130+
# Run all parts of `f` that do not depend on state
131+
state_invariant_pieces = f_state_invariant()
132+
f! = %new_opaque_closure(f_rhs, state_invariant_pieces)
133+
DAEFunction(f!), differential_vars
134+
end
135+
```
136+
137+
"""
138+
function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::TornCacheKey, world::UInt, settings::Settings, init_key::Union{TornCacheKey, Nothing})
139+
result = state.result
140+
# TODO: We should not have to recompute this here
141+
142+
ir_factory = copy(ci.inferred.ir)
143+
pushfirst!(ir_factory.argtypes, Settings)
144+
pushfirst!(ir_factory.argtypes, typeof(factory))
145+
compact = IncrementalCompact(ir_factory)
146+
147+
# Create a small opaque closure to adapt from SciML ABI to our own internal ABI
148+
argt = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64}
149+
sicm = ()
150+
if settings.skip_optimizations
151+
daef_ci = rhs_finish_noopt!(state, ci, key, world, settings, 1)
152+
oc = sciml_to_internal_abi_noopt!(copy(ci.inferred.ir), state, daef_ci, settings)
153+
else
154+
var_eq_matching = matching_for_key(state, key)
155+
156+
torn_ci = find_matching_ci(ci->isa(ci.owner, TornIRSpec) && ci.owner.key == key, ci.def, world)
157+
torn_ir = torn_ci.inferred
158+
159+
(; ir_sicm) = torn_ir
160+
161+
local line
162+
if ir_sicm !== nothing
163+
sicm_ci = find_matching_ci(ci->isa(ci.owner, SICMSpec) && ci.owner.key == key, ci.def, world)
164+
@assert sicm_ci !== nothing
165+
166+
line = result.ir[SSAValue(1)][:line]
167+
param_list = flatten_parameter!(Compiler.fallback_lattice, compact, ci.inferred.ir.argtypes[1:end], argn->Argument(2+argn), line, settings)
168+
sicm = @insert_instruction_here compact line settings invoke(param_list, sicm_ci)::Tuple
169+
end
170+
171+
daef_ci = rhs_finish!(state, ci, key, world, settings, 1)
172+
oc = sciml_to_internal_abi!(copy(ci.inferred.ir), state, daef_ci, key, var_eq_matching, settings)
173+
end
165174

166175
line = result.ir[SSAValue(1)][:line]
167176

@@ -173,6 +182,7 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
173182

174183
new_oc = @insert_instruction_here compact line settings (:new_opaque_closure)(argt, Union{}, Nothing, true, oc_source_method, sicm)::Core.OpaqueClosure true
175184

185+
all_states = filter(var -> classify_var(result, key, var) != AlgebraicDerivative, continuous_variables(state))
176186
differential_states = Bool[v in key.diff_states for v in all_states]
177187

178188
if init_key !== nothing
@@ -192,6 +202,6 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
192202
empty!(ir_factory.cfg.blocks[1].succs)
193203
Compiler.verify_ir(ir_factory)
194204

195-
slotnames = [[:factory, :settings]; Symbol.(:arg, 1:(length(ir_factory.argtypes) - 2))]
205+
slotnames = [:factory, :settings, :f]
196206
return ir_factory, slotnames
197207
end

0 commit comments

Comments
 (0)