Skip to content

Commit faf7d3f

Browse files
committed
WIP: Re-target IPO path to Jameson's branch
1 parent 24b782b commit faf7d3f

File tree

8 files changed

+94
-107
lines changed

8 files changed

+94
-107
lines changed

Manifest.toml

Lines changed: 2 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/analysis/compiler.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -461,8 +461,12 @@ function dae_result_for_inst(interp, inst::CC.Instruction)
461461
return result isa Union{DAEIPOResult, UncompilableIPOResult} ? result : nothing
462462
end
463463
else
464-
codeinst = CC.get(CC.code_cache(interp), mi, nothing)
465-
codeinst === nothing && return nothing
464+
if isa(mi, MethodInstance)
465+
codeinst = CC.get(CC.code_cache(interp), mi, nothing)
466+
codeinst === nothing && return nothing
467+
else
468+
codeinst = mi::CodeInstance
469+
end
466470
result = CC.traverse_analysis_results(codeinst) do @nospecialize result
467471
return result isa Union{DAEIPOResult, UncompilableIPOResult} ? result : nothing
468472
end
@@ -997,8 +1001,11 @@ end
9971001
return result
9981002
end
9991003

1004+
mi_or_ci = stmt.args[1]
1005+
isva = (isa(mi_or_ci, CodeInstance) ? mi_or_ci.def.def : mi_or_ci.def).isva
1006+
10001007
callee_argtypes = CC.va_process_argtypes(CC.optimizer_lattice(analysis_interp),
1001-
CC.collect_argtypes(analysis_interp, stmt.args[2:end], nothing, irsv), UInt(length(result.argtypes)), stmt.args[1].def.isva)
1008+
CC.collect_argtypes(analysis_interp, stmt.args[2:end], nothing, irsv), UInt(length(result.argtypes)), isva)
10021009
mapping = CalleeMapping(CC.optimizer_lattice(analysis_interp), callee_argtypes, result)
10031010
end
10041011
append!(warnings, result.warnings)

src/analysis/interpreter.jl

Lines changed: 38 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -170,54 +170,47 @@ end
170170
#=CC.=#get_inference_world(interp::DAEInterpreter) = interp.world
171171
CC.get_inference_cache(interp::DAEInterpreter) = interp.inf_cache
172172

173-
if Base.__has_internal_change(v"1.12-alpha", :methodspecialization)
174-
"""
175-
struct AnalysisSpec
176-
177-
The cache partition for DAECompiler analysis results. This is essentially
178-
equivalent to a regular type inference, except that optimization ir prohibited
179-
from inlining any functions that have frules and we perform the DAE analysis
180-
on every ir after optimization.
181-
"""
182-
struct AnalysisSpec; end
183-
184-
"""
185-
struct RHSSpec
186-
187-
Cache partition for the RHS
188-
"""
189-
struct RHSSpec
190-
key::TornCacheKey
191-
ordinal::Int
192-
end
173+
"""
174+
struct AnalysisSpec
193175
194-
Base.show(io::IO, ms::MethodSpecialization{RHSSpec}) = print(io, "RHS Spec#$(ms.data.ordinal) for ", ms.def)
176+
The cache partition for DAECompiler analysis results. This is essentially
177+
equivalent to a regular type inference, except that optimization ir prohibited
178+
from inlining any functions that have frules and we perform the DAE analysis
179+
on every ir after optimization.
180+
"""
181+
struct AnalysisSpec; end
182+
183+
"""
184+
struct RHSSpec
195185
186+
Cache partition for the RHS
187+
"""
188+
struct RHSSpec
189+
key::TornCacheKey
190+
ordinal::Int
191+
end
196192

197-
"""
198-
struct SICMSpec
199193

200-
Cache partition for the state-invariant prologue
201-
"""
194+
"""
202195
struct SICMSpec
203-
key::TornCacheKey
204-
end
205196
206-
Base.show(io::IO, ms::MethodSpecialization{SICMSpec}) = print(io, "SICM Spec for ", ms.def)
197+
Cache partition for the state-invariant prologue
198+
"""
199+
struct SICMSpec
200+
key::TornCacheKey
201+
end
207202

208-
function CC.code_cache(interp::DAEInterpreter)
209-
if interp.ipo_analysis_mode
210-
return CC.WorldView(
211-
CC.InternalCodeCache(Core.MethodSpecialization{AnalysisSpec}),
212-
CC.WorldRange(CC.get_inference_world(interp)))
213-
else
214-
return interp.code_cache
215-
end
203+
204+
function CC.code_cache(interp::DAEInterpreter)
205+
if interp.ipo_analysis_mode
206+
return CC.WorldView(
207+
CC.InternalCodeCache(AnalysisSpec()),
208+
CC.WorldRange(CC.get_inference_world(interp)))
209+
else
210+
return interp.code_cache
216211
end
217-
else
218-
CC.cache_owner(interp::DAEInterpreter) = interp.code_cache
219-
CC.method_table(interp::DAEInterpreter) = interp.method_table
220212
end
213+
CC.cache_owner(interp::DAEInterpreter) = interp.ipo_analysis_mode ? AnalysisSpec() : interp.code_cache
221214

222215
# abstract interpretation
223216
# -----------------------
@@ -813,10 +806,14 @@ struct MappingInfo <: CC.CallInfo
813806
end
814807

815808
function _abstract_eval_invoke_inst(interp::DAEInterpreter, inst::Union{CC.Instruction, Nothing}, @nospecialize(stmt), irsv::IRInterpretationState)
816-
mi = stmt.args[1]
809+
invokee = stmt.args[1]
817810
RT = Pair{Any, Tuple{Bool, Bool}}
818811
good_effects = (true, true)
819-
m = mi.def
812+
if isa(invokee, Core.CodeInstance)
813+
m = invokee.def.def
814+
else
815+
m = invokee.def
816+
end
820817
if m === variable_method0 || m === variable_method1
821818
# Nothing to do - we'll read the incidence out of the ssavaluetypes
822819
return RT(nothing, good_effects)

src/cache.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,6 @@ end
5858
CalleeInternal
5959
end
6060

61-
@static if !Base.__has_internal_change(v"1.12-alpha", :methodspecialization)
62-
const MethodSpecialization = Core.MethodInstance
63-
else
64-
import Core: MethodSpecialization
65-
end
66-
6761
struct DAEIPOResult
6862
ir::IRCode
6963
extended_rt::Any
@@ -90,8 +84,8 @@ struct DAEIPOResult
9084
tearing_cache::Dict{TornCacheKey, TornIR}
9185

9286
# TODO: Should this be looked up via the regular code instance cache instead?
93-
sicm_cache::Dict{TornCacheKey, MethodSpecialization}
94-
dae_finish_cache::Dict{TornCacheKey, MethodSpecialization}
87+
sicm_cache::Dict{TornCacheKey, CodeInstance}
88+
dae_finish_cache::Dict{TornCacheKey, Vector{CodeInstance}}
9589
end
9690

9791
struct UncompilableIPOResult

src/transform/common.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ function compile_invokes!(ir, interp)
4040
inst = ir.stmts[i]
4141
e = inst[:inst]
4242
if isexpr(e, :invoke)
43-
mi = e.args[1]::MethodInstance
44-
if !CC.haskey(CC.code_cache(interp), mi)
45-
CC.typeinf_ext_toplevel(interp, mi, CC.SOURCE_MODE_ABI)
43+
mi = e.args[1]
44+
if isa(mi, MethodInstance)
45+
if !CC.haskey(CC.code_cache(interp), mi)
46+
CC.typeinf_ext_toplevel(interp, mi, CC.SOURCE_MODE_ABI)
47+
end
4648
end
4749
end
4850
end
@@ -212,6 +214,9 @@ function check_for_daecompiler_intrinstics(ir::IRCode)
212214
inst = ir[SSAValue(i)][:inst]
213215
isexpr(inst, :invoke) || continue
214216
mi = inst.args[1]
217+
if isa(mi, CodeInstance)
218+
mi = mi.def
219+
end
215220
if mi.def.module == DAECompiler.Intrinsics
216221
throw(UnexpectedIntrinsicException(inst))
217222
end

src/transform/dae_finish.jl

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,19 @@ end
262262

263263
const VectorViewType = SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int}}, true}
264264

265+
function cache_dae_ci!(old_ci, src, debuginfo, owner)
266+
daef_ci = CC.engine_reserve(old_ci.def, owner)
267+
ccall(:jl_fill_codeinst, Cvoid, (Any, Any, Any, Any, Int32, UInt, UInt, UInt32, Any, Any, Any),
268+
daef_ci, Tuple{}, Union{}, nothing, Int32(0),
269+
UInt(1)#=ci.min_world=#, old_ci.max_world,
270+
old_ci.ipo_purity_bits, nothing, nothing, CC.empty_edges)
271+
ccall(:jl_update_codeinst, Cvoid, (Any, Any, Int32, UInt, UInt, UInt32, Any, UInt8, Any, Any),
272+
daef_ci, src, Int32(0), UInt(1)#=ci.min_world=#, old_ci.max_world, old_ci.ipo_purity_bits,
273+
nothing, 0x0, debuginfo, CC.empty_edges)
274+
ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), old_ci.def, daef_ci)
275+
return daef_ci
276+
end
277+
265278
function dae_finish_ipo!(
266279
interp,
267280
ci::CodeInstance,
@@ -288,6 +301,7 @@ function dae_finish_ipo!(
288301
old_daef_mi = nothing
289302
assigned_slots = falses(length(result.total_incidence))
290303

304+
cis = Vector{CodeInstance}()
291305
for (ir_ordinal, ir) in enumerate(torn.ir_seq)
292306
ir = torn.ir_seq[ir_ordinal]
293307

@@ -338,11 +352,11 @@ function dae_finish_ipo!(
338352
spec_data = stmt.args[1]
339353
callee_key = stmt.args[1][2]
340354
callee_ordinal = stmt.args[1][end]::Int
341-
callee_daef_mi = dae_finish_ipo!(interp, callee_ci, callee_key, callee_ordinal)
355+
callee_daef_cis = dae_finish_ipo!(interp, callee_ci, callee_key, callee_ordinal)
342356
# Allocate a continuous block of variables for all callee alg and diff states
343357

344358
empty!(stmt.args)
345-
push!(stmt.args, callee_daef_mi)
359+
push!(stmt.args, callee_daef_cis[1])
346360
push!(stmt.args, closure_env)
347361
push!(stmt.args, in_vars)
348362

@@ -402,33 +416,17 @@ function dae_finish_ipo!(
402416
widen_extra_info!(ir)
403417
src = ir_to_src(ir)
404418

405-
daef_mi = MethodSpecialization{RHSSpec}(ci.def, Tuple{}, Tuple{Tuple, Tuple, (VectorViewType for _ in arg_range)..., Float64})
406-
daef_mi.data = RHSSpec(key, ir_ordinal)
419+
abi = Tuple{Tuple, Tuple, (VectorViewType for _ in arg_range)..., Float64}
420+
owner = Core.ABIOverwrite(abi, RHSSpec(key, ir_ordinal))
421+
daef_ci = cache_dae_ci!(ci, src, src.debuginfo, owner)
407422

408-
daef_ci = CodeInstance(daef_mi, Tuple, Union{}, nothing,
409-
src, Int32(0), UInt(1)#=ci.min_world=#, ci.max_world, ci.ipo_purity_bits, ci.purity_bits,
410-
nothing, 0x0, src.debuginfo)
411-
412-
@atomic :release daef_mi.cache = daef_ci
413423
global nrhscompiles += 1
414-
415-
if old_daef_mi !== nothing
416-
@atomic :release old_daef_mi.next = daef_mi
417-
end
418-
old_daef_mi = daef_mi
419-
420-
if rhs_ms === nothing
421-
rhs_ms = daef_mi
422-
end
424+
push!(cis, daef_ci)
423425
end
424426

425-
result.dae_finish_cache[key] = rhs_ms
427+
result.dae_finish_cache[key] = cis
426428

427-
ms = rhs_ms
428-
while !isa(ms.data, RHSSpec) || ms.data.ordinal != ordinal
429-
ms = ms.next
430-
end
431-
return ms
429+
return cis
432430
end
433431

434432
function ir_to_src(ir::IRCode)
@@ -499,6 +497,7 @@ function dae_factory_gen(world::UInt, source::LineNumberNode, _, @nospecialize(f
499497
src.ssavaluetypes = length(src.code)
500498
src.min_world = @atomic codeinst.min_world
501499
src.max_world = @atomic codeinst.max_world
500+
src.edges = codeinst.edges
502501

503502
return src
504503
end
@@ -527,7 +526,7 @@ function dae_factory_gen(interp, ci::CodeInstance, key)
527526

528527
argt = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64}
529528

530-
daef_ci = dae_finish_ipo!(interp, ci, key, 1)
529+
daef_cis = dae_finish_ipo!(interp, ci, key, 1)
531530

532531
# Create a small opaque closure to adapt from SciML ABI to our own internal
533532
# ABI
@@ -582,7 +581,7 @@ function dae_factory_gen(interp, ci::CodeInstance, key)
582581
oc_sicm = insert_node_here!(oc_compact,
583582
NewInstruction(Expr(:call, getfield, Argument(1), 1), Tuple, line))
584583
insert_node_here!(oc_compact,
585-
NewInstruction(Expr(:invoke, daef_ci, oc_sicm, (), out_du_mm, out_eq, in_u_mm, in_u_unassgn, in_du_unassgn, in_alg, Argument(6)), Nothing, line))
584+
NewInstruction(Expr(:invoke, daef_cis[1], oc_sicm, (), out_du_mm, out_eq, in_u_mm, in_u_unassgn, in_du_unassgn, in_alg, Argument(6)), Nothing, line))
586585

587586
# Manually apply mass matrix
588587
bc = insert_node_here!(oc_compact,

src/transform/tearing_schedule_ipo.jl

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ function tearing_schedule!(interp, ci::CodeInstance, key::TornCacheKey)
619619
end
620620

621621
callee_codeinst = CC.get(CC.code_cache(interp), stmt.args[1], nothing)
622-
callee_sicm_mi = tearing_schedule!(interp, callee_codeinst, callee_key)
622+
callee_sicm_ci = tearing_schedule!(interp, callee_codeinst, callee_key)
623623

624624
inst[:type] = Any
625625
inst[:flag] = UInt32(0)
@@ -630,8 +630,8 @@ function tearing_schedule!(interp, ci::CodeInstance, key::TornCacheKey)
630630
(AssignedDiff, UnassignedDiff, Algebraic, Explicit))...)
631631
resize!(stmt.args, 1)
632632

633-
if !isdefined(callee_sicm_mi.cache, :rettype_const)
634-
new_stmt.args[1] = callee_sicm_mi
633+
if !isdefined(callee_sicm_ci, :rettype_const)
634+
new_stmt.args[1] = callee_sicm_ci
635635

636636
urs = userefs(new_stmt)
637637
for ur in urs
@@ -646,7 +646,7 @@ function tearing_schedule!(interp, ci::CodeInstance, key::TornCacheKey)
646646
state = insert_node_here!(compact, NewInstruction(inst; stmt=new_stmt, type=Tuple, flag=UInt32(0)))
647647
push!(stmt.args, SICMSSAValue(state.id))
648648
else
649-
push!(stmt.args, callee_sicm_mi.cache.rettype_const)
649+
push!(stmt.args, callee_sicm_ci.rettype_const)
650650
end
651651
elseif stmt === nothing || isa(stmt, ReturnNode)
652652
continue
@@ -948,24 +948,10 @@ function tearing_schedule!(interp, ci::CodeInstance, key::TornCacheKey)
948948
debuginfo = src.debuginfo
949949
end
950950

951-
sicm_mi = MethodSpecialization{SICMSpec}(ci.def, Tuple{}, sig)
952-
sicm_mi.data = SICMSpec(key)
951+
sicm_ci = cache_dae_ci!(ci, src, debuginfo, Core.ABIOverwrite(sig, SICMSpec(key)))
953952

954-
sicm_ci = CodeInstance(sicm_mi, Tuple, Union{}, ir_sicm === nothing ? () : nothing,
955-
src, ir_sicm === nothing ? Int32(0x3) : Int32(0), UInt(1)#=ci.min_world=#, ci.max_world, ci.ipo_purity_bits, ci.purity_bits,
956-
nothing, 0x0, debuginfo)
957-
958-
@atomic :release sicm_mi.cache = sicm_ci
959-
960-
result.sicm_cache[key] = sicm_mi
953+
result.sicm_cache[key] = sicm_ci
961954
result.tearing_cache[key] = TornIR(ir_sicm, irs)
962955

963-
cache_mi = ci.def
964-
while isdefined(cache_mi, :next)
965-
cache_mi = @atomic cache_mi.next
966-
end
967-
@atomic :release cache_mi.next = sicm_mi
968-
global nsicmcompiles += 1
969-
970-
return sicm_mi
956+
return sicm_ci
971957
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
using Test
22

3+
#=
34
@testset "state_mapping.jl" include("state_mapping.jl")
45
@testset "interpreter.jl" include("interpreter.jl")
56
@testset "compiler_and_lattice.jl" include("compiler_and_lattice.jl")
67
@testset "JITOpaqueClosures.jl" include("JITOpaqueClosures.jl")
78
@testset "robertson.jl" include("robertson.jl")
9+
=#
810
@testset "ipo.jl" include("ipo.jl")
911
@testset "lorenz.jl" include("lorenz_tests.jl")
1012
@testset "pendulum.jl" include("pendulum_tests.jl")

0 commit comments

Comments
 (0)