diff --git a/src/analysis/ADAnalyzer.jl b/src/analysis/ADAnalyzer.jl index 70ab31e..c871653 100644 --- a/src/analysis/ADAnalyzer.jl +++ b/src/analysis/ADAnalyzer.jl @@ -85,7 +85,7 @@ end return AnalyzedSource(ir, slotnames, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva) end -@override function Compiler.transform_result_for_local_cache(interp::ADAnalyzer, result::InferenceResult) +@override function Compiler.transform_result_for_local_cache(interp::ADAnalyzer, result::InferenceResult, edges::SimpleVector) if Compiler.result_is_constabi(interp, result) return nothing end diff --git a/src/transform/codegen/dae_factory.jl b/src/transform/codegen/dae_factory.jl index f53ed39..f1de56a 100644 --- a/src/transform/codegen/dae_factory.jl +++ b/src/transform/codegen/dae_factory.jl @@ -45,7 +45,7 @@ end const SCIML_ABI = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64} -function sciml_to_internal_abi!(ir::IRCode, state::TransformationState, internal_ci::CodeInstance, key::TornCacheKey, var_eq_matching, settings::Settings) +function sciml_to_internal_abi!(ir::IRCode, state::TransformationState, internal_ci::CodeInstance, key::TornCacheKey, var_eq_matching, world::UInt, settings::Settings) (; result, structure) = state numstates = zeros(Int, Int(LastEquationStateKind)) @@ -111,12 +111,7 @@ function sciml_to_internal_abi!(ir::IRCode, state::TransformationState, internal resize!(ir.cfg.blocks, 1) empty!(ir.cfg.blocks[1].succs) Compiler.verify_ir(ir) - - @async @eval Main begin - interface_ir = $ir - end - - return Core.OpaqueClosure(ir; slotnames = [:captures, :out, :du, :u, :p, :t]) + return optimized_opaque_closure(ir, world; slotnames = [:captures, :out, :du, :u, :p, :t]) end """ @@ -173,7 +168,7 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Unio end daef_ci = rhs_finish!(state, ci, key, world, settings, 1) - oc = sciml_to_internal_abi!(copy(ci.inferred.ir), state, daef_ci, key, var_eq_matching, settings) + oc = sciml_to_internal_abi!(copy(ci.inferred.ir), state, daef_ci, key, var_eq_matching, world, settings) end line = result.ir[SSAValue(1)][:line] diff --git a/src/transform/codegen/init_factory.jl b/src/transform/codegen/init_factory.jl index 996cb28..4606c09 100644 --- a/src/transform/codegen/init_factory.jl +++ b/src/transform/codegen/init_factory.jl @@ -88,7 +88,7 @@ function init_uncompress_gen!(compact::Compiler.IncrementalCompact, result::DAEI @insert_instruction_here(oc_compact, line, settings, (return out_arr)::Vector{Float64}) ir_oc = Compiler.finish(oc_compact) - oc = Core.OpaqueClosure(ir_oc) + oc = optimized_opaque_closure(ir_oc, world) line = result.ir[SSAValue(1)][:line] diff --git a/src/transform/codegen/ode_factory.jl b/src/transform/codegen/ode_factory.jl index f725db7..c0b9414 100644 --- a/src/transform/codegen/ode_factory.jl +++ b/src/transform/codegen/ode_factory.jl @@ -139,7 +139,7 @@ function ode_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn interface_ir = Compiler.finish(interface_ic) maybe_rewrite_debuginfo!(interface_ir, settings) Compiler.verify_ir(interface_ir) - interface_oc = Core.OpaqueClosure(interface_ir; slotnames = [:self, :du, :u, :p, :t]) + interface_oc = optimized_opaque_closure(interface_ir, world; slotnames = [:self, :du, :u, :p, :t]) line = result.ir[SSAValue(1)][:line] diff --git a/src/transform/common.jl b/src/transform/common.jl index 042a29e..2d6968f 100644 --- a/src/transform/common.jl +++ b/src/transform/common.jl @@ -98,6 +98,44 @@ function cache_dae_ci!(old_ci, src, debuginfo, abi, owner; rettype=Tuple) return daef_ci end +function optimized_opaque_closure(ir::IRCode, world::UInt; slotnames = nothing) + oc = Core.OpaqueClosure(ir) + adjust_world_bounds!(oc) + optimized_oc = optimize_opaque_closure!(oc, world; slotnames) + adjust_world_bounds!(optimized_oc) + return optimized_oc +end + +function optimize_opaque_closure!(oc::Core.OpaqueClosure, world::UInt; slotnames = nothing) + method = oc.source + ci = method.specializations.cache + ir = reinfer_and_inline(ci, world) + return Core.OpaqueClosure(ir; slotnames) +end + +# Not sure if/why this is necessary or even correct, but +# otherwise the `CodeInstance` bounds are outdated. +function adjust_world_bounds!(oc::Core.OpaqueClosure) + ci = oc.source.specializations.cache + @atomic ci.min_world = ci.inferred.min_world + @atomic ci.max_world = ci.inferred.max_world +end + +function reinfer_and_inline(ci::CodeInstance, world::UInt) + interp = Compiler.NativeInterpreter(world) + mi = Compiler.get_ci_mi(ci) + argtypes = collect(Any, mi.specTypes.parameters) + irsv = Compiler.IRInterpretationState(interp, ci, mi, argtypes, world) + @assert irsv !== nothing + for stmt in irsv.ir.stmts + stmt[:flag] |= Compiler.IR_FLAG_REFINED + end + Compiler.ir_abstract_constant_propagation(interp, irsv) + state = Compiler.InliningState(interp) + ir = Compiler.ssa_inlining_pass!(irsv.ir, state, Compiler.propagate_inbounds(irsv)) + return ir +end + function replace_call!(ir::Union{IRCode,IncrementalCompact}, idx::SSAValue, @nospecialize(new_call), settings::Settings, source) replace_call!(ir, idx, new_call) settings.insert_stmt_debuginfo || return new_call