Skip to content

Commit f79c093

Browse files
committed
Optimize emitted OpaqueClosure
1 parent e1294dc commit f79c093

File tree

4 files changed

+43
-10
lines changed

4 files changed

+43
-10
lines changed

src/transform/codegen/dae_factory.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ end
4545

4646
const SCIML_ABI = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64}
4747

48-
function sciml_to_internal_abi!(ir::IRCode, state::TransformationState, internal_ci::CodeInstance, key::TornCacheKey, var_eq_matching, settings::Settings)
48+
function sciml_to_internal_abi!(ir::IRCode, state::TransformationState, internal_ci::CodeInstance, key::TornCacheKey, var_eq_matching, world::UInt, settings::Settings)
4949
(; result, structure) = state
5050

5151
numstates = zeros(Int, Int(LastEquationStateKind))
@@ -111,12 +111,7 @@ function sciml_to_internal_abi!(ir::IRCode, state::TransformationState, internal
111111
resize!(ir.cfg.blocks, 1)
112112
empty!(ir.cfg.blocks[1].succs)
113113
Compiler.verify_ir(ir)
114-
115-
@async @eval Main begin
116-
interface_ir = $ir
117-
end
118-
119-
return Core.OpaqueClosure(ir; slotnames = [:captures, :out, :du, :u, :p, :t])
114+
return optimized_opaque_closure(ir, world; slotnames = [:captures, :out, :du, :u, :p, :t])
120115
end
121116

122117
"""
@@ -173,7 +168,7 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Unio
173168
end
174169

175170
daef_ci = rhs_finish!(state, ci, key, world, settings, 1)
176-
oc = sciml_to_internal_abi!(copy(ci.inferred.ir), state, daef_ci, key, var_eq_matching, settings)
171+
oc = sciml_to_internal_abi!(copy(ci.inferred.ir), state, daef_ci, key, var_eq_matching, world, settings)
177172
end
178173

179174
line = result.ir[SSAValue(1)][:line]

src/transform/codegen/init_factory.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ function init_uncompress_gen!(compact::Compiler.IncrementalCompact, result::DAEI
8888
@insert_instruction_here(oc_compact, line, settings, (return out_arr)::Vector{Float64})
8989

9090
ir_oc = Compiler.finish(oc_compact)
91-
oc = Core.OpaqueClosure(ir_oc)
91+
oc = optimized_opaque_closure(ir_oc, world)
9292

9393
line = result.ir[SSAValue(1)][:line]
9494

src/transform/codegen/ode_factory.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ function ode_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
139139
interface_ir = Compiler.finish(interface_ic)
140140
maybe_rewrite_debuginfo!(interface_ir, settings)
141141
Compiler.verify_ir(interface_ir)
142-
interface_oc = Core.OpaqueClosure(interface_ir; slotnames = [:self, :du, :u, :p, :t])
142+
interface_oc = optimized_opaque_closure(interface_ir, world; slotnames = [:self, :du, :u, :p, :t])
143143

144144
line = result.ir[SSAValue(1)][:line]
145145

src/transform/common.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,44 @@ function cache_dae_ci!(old_ci, src, debuginfo, abi, owner; rettype=Tuple)
9898
return daef_ci
9999
end
100100

101+
function optimized_opaque_closure(ir::IRCode, world::UInt; slotnames = nothing)
102+
oc = Core.OpaqueClosure(ir)
103+
adjust_world_bounds!(oc)
104+
optimized_oc = optimize_opaque_closure!(oc, world; slotnames)
105+
adjust_world_bounds!(optimized_oc)
106+
return optimized_oc
107+
end
108+
109+
function optimize_opaque_closure!(oc::Core.OpaqueClosure, world::UInt; slotnames = nothing)
110+
method = oc.source
111+
ci = method.specializations.cache
112+
ir = reinfer_and_inline(ci, world)
113+
return Core.OpaqueClosure(ir; slotnames)
114+
end
115+
116+
# Not sure if/why this is necessary or even correct, but
117+
# otherwise the `CodeInstance` bounds are outdated.
118+
function adjust_world_bounds!(oc::Core.OpaqueClosure)
119+
ci = oc.source.specializations.cache
120+
@atomic ci.min_world = ci.inferred.min_world
121+
@atomic ci.max_world = ci.inferred.max_world
122+
end
123+
124+
function reinfer_and_inline(ci::CodeInstance, world::UInt)
125+
interp = Compiler.NativeInterpreter(world)
126+
mi = Compiler.get_ci_mi(ci)
127+
argtypes = collect(Any, mi.specTypes.parameters)
128+
irsv = Compiler.IRInterpretationState(interp, ci, mi, argtypes, world)
129+
@assert irsv !== nothing
130+
for stmt in irsv.ir.stmts
131+
stmt[:flag] |= Compiler.IR_FLAG_REFINED
132+
end
133+
Compiler.ir_abstract_constant_propagation(interp, irsv)
134+
state = Compiler.InliningState(interp)
135+
ir = Compiler.ssa_inlining_pass!(irsv.ir, state, Compiler.propagate_inbounds(irsv))
136+
return ir
137+
end
138+
101139
function replace_call!(ir::Union{IRCode,IncrementalCompact}, idx::SSAValue, @nospecialize(new_call), settings::Settings, source)
102140
replace_call!(ir, idx, new_call)
103141
settings.insert_stmt_debuginfo || return new_call

0 commit comments

Comments
 (0)