Skip to content

Commit 24d2f4a

Browse files
serenity4aviatesktopolarity
authored
Defer global caching of CodeInstance to post-optimization step (#58343)
This PR extracts the caching improvements from #56687, implemented by @aviatesk. It essentially defers global caching to the post-optimization step, giving a temporary cache to the optimizer instead of relying on the global cache. The issue with caching globally before optimization is that any errors occurring within the optimizer may leave a partially initialized `CodeInstance` in the cache, which was meant to be updated post-optimization. Exceptions should not be thrown in the native compilation pipeline, but abstract interpreters extending optimization routines will frequently encounter some during an iterative development (see JuliaComputing/DAECompiler.jl#25). An extra benefit from deferring global caching is that the optimizer can then more safely update `CodeInstance`s with new inference information, as is the intent of #56687. --------- Co-authored-by: Shuhei Kadowaki <[email protected]> Co-authored-by: Cody Tapscott <[email protected]>
1 parent 0cb1adb commit 24d2f4a

File tree

2 files changed

+84
-36
lines changed

2 files changed

+84
-36
lines changed

Compiler/src/optimize.jl

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -147,16 +147,46 @@ struct InliningState{Interp<:AbstractInterpreter}
147147
edges::Vector{Any}
148148
world::UInt
149149
interp::Interp
150+
opt_cache::IdDict{MethodInstance,CodeInstance}
150151
end
151-
function InliningState(sv::InferenceState, interp::AbstractInterpreter)
152-
return InliningState(sv.edges, frame_world(sv), interp)
152+
function InliningState(sv::InferenceState, interp::AbstractInterpreter,
153+
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
154+
return InliningState(sv.edges, frame_world(sv), interp, opt_cache)
153155
end
154-
function InliningState(interp::AbstractInterpreter)
155-
return InliningState(Any[], get_inference_world(interp), interp)
156+
function InliningState(interp::AbstractInterpreter,
157+
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
158+
return InliningState(Any[], get_inference_world(interp), interp, opt_cache)
159+
end
160+
161+
struct OptimizerCache{CodeCache}
162+
wvc::WorldView{CodeCache}
163+
owner
164+
opt_cache::IdDict{MethodInstance,CodeInstance}
165+
function OptimizerCache(
166+
wvc::WorldView{CodeCache},
167+
@nospecialize(owner),
168+
opt_cache::IdDict{MethodInstance,CodeInstance}) where CodeCache
169+
new{CodeCache}(wvc, owner, opt_cache)
170+
end
171+
end
172+
function get((; wvc, owner, opt_cache)::OptimizerCache, mi::MethodInstance, default)
173+
if haskey(opt_cache, mi)
174+
codeinst = opt_cache[mi]
175+
@assert codeinst.min_world wvc.worlds.min_world &&
176+
wvc.worlds.max_world codeinst.max_world &&
177+
codeinst.owner === owner
178+
@assert isdefined(codeinst, :inferred) && codeinst.inferred === nothing
179+
return codeinst
180+
end
181+
return get(wvc, mi, default)
156182
end
157183

158184
# get `code_cache(::AbstractInterpreter)` from `state::InliningState`
159-
code_cache(state::InliningState) = WorldView(code_cache(state.interp), state.world)
185+
function code_cache(state::InliningState)
186+
cache = WorldView(code_cache(state.interp), state.world)
187+
owner = cache_owner(state.interp)
188+
return OptimizerCache(cache, owner, state.opt_cache)
189+
end
160190

161191
mutable struct OptimizationResult
162192
ir::IRCode
@@ -183,13 +213,15 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter}
183213
bb_vartables::Vector{Union{Nothing,VarTable}}
184214
insert_coverage::Bool
185215
end
186-
function OptimizationState(sv::InferenceState, interp::AbstractInterpreter)
187-
inlining = InliningState(sv, interp)
216+
function OptimizationState(sv::InferenceState, interp::AbstractInterpreter,
217+
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
218+
inlining = InliningState(sv, interp, opt_cache)
188219
return OptimizationState(sv.linfo, sv.src, nothing, sv.stmt_info, sv.mod,
189220
sv.sptypes, sv.slottypes, inlining, sv.cfg,
190221
sv.unreachable, sv.bb_vartables, sv.insert_coverage)
191222
end
192-
function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractInterpreter)
223+
function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractInterpreter,
224+
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
193225
# prepare src for running optimization passes if it isn't already
194226
nssavalues = src.ssavaluetypes
195227
if nssavalues isa Int
@@ -209,7 +241,7 @@ function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractIn
209241
mod = isa(def, Method) ? def.module : def
210242
# Allow using the global MI cache, but don't track edges.
211243
# This method is mostly used for unit testing the optimizer
212-
inlining = InliningState(interp)
244+
inlining = InliningState(interp, opt_cache)
213245
cfg = compute_basic_blocks(src.code)
214246
unreachable = BitSet()
215247
bb_vartables = Union{VarTable,Nothing}[]

Compiler/src/typeinfer.jl

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState, validation
110110
elseif isdefined(result, :ci)
111111
edges = result_edges(interp, caller)
112112
ci = result.ci
113+
mi = result.linfo
113114
# if we aren't cached, we don't need this edge
114115
# but our caller might, so let's just make it anyways
115116
if last(result.valid_worlds) >= validation_world
@@ -143,22 +144,27 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState, validation
143144
resize!(inferred_result.slottypes::Vector{Any}, nslots)
144145
resize!(inferred_result.slotnames, nslots)
145146
end
146-
inferred_result = maybe_compress_codeinfo(interp, result.linfo, inferred_result)
147+
inferred_result = maybe_compress_codeinfo(interp, mi, inferred_result)
147148
result.is_src_volatile = false
148149
elseif ci.owner === nothing
149150
# The global cache can only handle objects that codegen understands
150151
inferred_result = nothing
151152
end
152153
end
153154
if debuginfo === nothing
154-
debuginfo = DebugInfo(result.linfo)
155+
debuginfo = DebugInfo(mi)
155156
end
157+
min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
158+
ipo_effects = encode_effects(result.ipo_effects)
156159
time_now = _time_ns()
157160
time_self_ns = caller.time_self_ns + (time_now - time_before)
158161
time_total = (time_now - caller.time_start - caller.time_paused) * 1e-9
159162
ccall(:jl_update_codeinst, Cvoid, (Any, Any, Int32, UInt, UInt, UInt32, Any, Float64, Float64, Float64, Any, Any),
160-
ci, inferred_result, const_flag, first(result.valid_worlds), last(result.valid_worlds), encode_effects(result.ipo_effects),
163+
ci, inferred_result, const_flag, min_world, max_world, ipo_effects,
161164
result.analysis_results, time_total, caller.time_caches, time_self_ns * 1e-9, debuginfo, edges)
165+
if is_cached(caller) # CACHE_MODE_GLOBAL
166+
cache_result!(interp, result, ci)
167+
end
162168
engine_reject(interp, ci)
163169
codegen = codegen_cache(interp)
164170
if !discard_src && codegen !== nothing && (isa(uncompressed, CodeInfo) || isa(uncompressed, OptimizationState))
@@ -171,7 +177,6 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState, validation
171177
# This is necessary to get decent bootstrapping performance
172178
# when compiling the compiler to inject everything eagerly
173179
# where codegen can start finding and using it right away
174-
mi = result.linfo
175180
if mi.def isa Method && isa_compileable_sig(mi) && is_cached(caller)
176181
ccall(:jl_add_codeinst_to_jit, Cvoid, (Any, Any), ci, uncompressed)
177182
end
@@ -181,6 +186,11 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState, validation
181186
return nothing
182187
end
183188

189+
function cache_result!(interp::AbstractInterpreter, result::InferenceResult, ci::CodeInstance)
190+
mi = result.linfo
191+
code_cache(interp)[mi] = ci
192+
end
193+
184194
function finish!(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInstance, src::CodeInfo)
185195
user_edges = src.edges
186196
edges = user_edges isa SimpleVector ? user_edges : user_edges === nothing ? Core.svec() : Core.svec(user_edges...)
@@ -215,11 +225,13 @@ function finish!(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInstan
215225
end
216226

217227
function finish_nocycle(::AbstractInterpreter, frame::InferenceState, time_before::UInt64)
218-
finishinfer!(frame, frame.interp, frame.cycleid)
228+
opt_cache = IdDict{MethodInstance,CodeInstance}()
229+
finishinfer!(frame, frame.interp, frame.cycleid, opt_cache)
219230
opt = frame.result.src
220231
if opt isa OptimizationState # implies `may_optimize(caller.interp) === true`
221232
optimize(frame.interp, opt, frame.result)
222233
end
234+
empty!(opt_cache)
223235
validation_world = get_world_counter()
224236
finish!(frame.interp, frame, validation_world, time_before)
225237
if isdefined(frame.result, :ci)
@@ -249,10 +261,11 @@ function finish_cycle(::AbstractInterpreter, frames::Vector{AbsIntState}, cyclei
249261
cycle_valid_worlds = intersect(cycle_valid_worlds, caller.world.valid_worlds)
250262
cycle_valid_effects = merge_effects(cycle_valid_effects, caller.ipo_effects)
251263
end
264+
opt_cache = IdDict{MethodInstance,CodeInstance}()
252265
for frameid = cycleid:length(frames)
253266
caller = frames[frameid]::InferenceState
254267
adjust_cycle_frame!(caller, cycle_valid_worlds, cycle_valid_effects)
255-
finishinfer!(caller, caller.interp, cycleid)
268+
finishinfer!(caller, caller.interp, cycleid, opt_cache)
256269
time_now = _time_ns()
257270
caller.time_self_ns += (time_now - time_before)
258271
time_before = time_now
@@ -273,6 +286,7 @@ function finish_cycle(::AbstractInterpreter, frames::Vector{AbsIntState}, cyclei
273286
caller.time_paused = UInt64(0)
274287
caller.time_caches = 0.0
275288
end
289+
empty!(opt_cache)
276290
cycletop = frames[cycleid]::InferenceState
277291
time_start = cycletop.time_start
278292
validation_world = get_world_counter()
@@ -434,22 +448,6 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, mi::MethodInstance
434448
return ci
435449
end
436450

437-
function cache_result!(interp::AbstractInterpreter, result::InferenceResult, ci::CodeInstance)
438-
@assert isdefined(ci, :inferred)
439-
# check if the existing linfo metadata is also sufficient to describe the current inference result
440-
# to decide if it is worth caching this right now
441-
mi = result.linfo
442-
cache = WorldView(code_cache(interp), result.valid_worlds)
443-
if haskey(cache, mi)
444-
ci = cache[mi]
445-
# n.b.: accurate edge representation might cause the CodeInstance for this to be constructed later
446-
@assert isdefined(ci, :inferred)
447-
return false
448-
end
449-
code_cache(interp)[mi] = ci
450-
return true
451-
end
452-
453451
function cycle_fix_limited(@nospecialize(typ), sv::InferenceState, cycleid::Int)
454452
if typ isa LimitedAccuracy
455453
frames = sv.callstack::Vector{AbsIntState}
@@ -579,7 +577,8 @@ const empty_edges = Core.svec()
579577

580578
# inference completed on `me`
581579
# update the MethodInstance
582-
function finishinfer!(me::InferenceState, interp::AbstractInterpreter, cycleid::Int)
580+
function finishinfer!(me::InferenceState, interp::AbstractInterpreter, cycleid::Int,
581+
opt_cache::IdDict{MethodInstance, CodeInstance})
583582
# prepare to run optimization passes on fulltree
584583
@assert isempty(me.ip)
585584
# inspect whether our inference had a limited result accuracy,
@@ -635,7 +634,7 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter, cycleid::
635634
# disable optimization if we've already obtained very accurate result
636635
!result_is_constabi(interp, result)
637636
if doopt
638-
result.src = OptimizationState(me, interp)
637+
result.src = OptimizationState(me, interp, opt_cache)
639638
else
640639
result.src = me.src # for reflection etc.
641640
end
@@ -670,23 +669,40 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter, cycleid::
670669
rettype_const = nothing
671670
const_flags = 0x0
672671
end
672+
673673
di = nothing
674674
edges = empty_edges # `edges` will be updated within `finish!`
675675
ci = result.ci
676+
min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
676677
ccall(:jl_fill_codeinst, Cvoid, (Any, Any, Any, Any, Int32, UInt, UInt, UInt32, Any, Any, Any),
677678
ci, widenconst(result_type), widenconst(result.exc_result), rettype_const, const_flags,
678-
first(result.valid_worlds), last(result.valid_worlds),
679+
min_world, max_world,
679680
encode_effects(result.ipo_effects), result.analysis_results, di, edges)
680681
if is_cached(me) # CACHE_MODE_GLOBAL
681-
cached_result = cache_result!(me.interp, result, ci)
682-
if !cached_result
682+
already_cached = is_already_cached(me.interp, result, ci)
683+
if already_cached
683684
me.cache_mode = CACHE_MODE_VOLATILE
685+
else
686+
opt_cache[result.linfo] = ci
684687
end
685688
end
686689
end
687690
nothing
688691
end
689692

693+
function is_already_cached(interp::AbstractInterpreter, result::InferenceResult, ci::CodeInstance)
694+
# check if the existing linfo metadata is also sufficient to describe the current inference result
695+
# to decide if it is worth caching this right now
696+
mi = result.linfo
697+
cache = WorldView(code_cache(interp), result.valid_worlds)
698+
if haskey(cache, mi)
699+
# n.b.: accurate edge representation might cause the CodeInstance for this to be constructed later
700+
@assert isdefined(cache[mi], :inferred)
701+
return true
702+
end
703+
return false
704+
end
705+
690706
# Iterate a series of back-edges that need registering, based on the provided forward edge list.
691707
# Back-edges are returned as (invokesig, item), where the item is a Binding, MethodInstance, or
692708
# MethodTable.

0 commit comments

Comments
 (0)