Skip to content

Commit 52bf170

Browse files
authored
Preserve CallMeta results during inference (#640)
* Store and reuse `CallMeta` in a new `CthulhuCallInfo` struct * Add comments * Misc fixes * Bump version
1 parent 673ce13 commit 52bf170

File tree

5 files changed

+110
-30
lines changed

5 files changed

+110
-30
lines changed

src/CthulhuBase.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using WidthLimitedIO
1111

1212
using Core: MethodInstance, MethodMatch
1313
using Core.IR
14-
using .CC: AbstractInterpreter, ApplyCallInfo, CallInfo as CCCallInfo, ConstCallInfo,
14+
using .CC: AbstractInterpreter, CallMeta, ApplyCallInfo, CallInfo as CCCallInfo, ConstCallInfo,
1515
EFFECTS_TOTAL, Effects, IncrementalCompact, InferenceParams, InferenceResult,
1616
InferenceState, IRCode, LimitedAccuracy, MethodMatchInfo, MethodResultPure,
1717
NativeInterpreter, NoCallInfo, OptimizationParams, OptimizationState,
@@ -469,7 +469,7 @@ function _descend(term::AbstractTerminal, interp::AbstractInterpreter, curs::Abs
469469
if sourcenode !== nothing
470470
show_sub_callsites = let callsite=callsite
471471
map(info.callinfos) do ci
472-
p = Base.unwrap_unionall(ci.def.specTypes).parameters
472+
p = Base.unwrap_unionall(get_ci(ci).def.specTypes).parameters
473473
if isa(sourcenode, TypedSyntax.MaybeTypedSyntaxNode) && length(p) == length(JuliaSyntax.children(sourcenode)) + 1
474474
newnode = copy(sourcenode)
475475
for (i, child) in enumerate(JuliaSyntax.children(newnode))

src/callsite.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,14 @@ get_ci(gci::CuCallInfo) = get_ci(gci.ci)
172172
get_rt(gci::CuCallInfo) = get_rt(gci.ci)
173173
get_effects(gci::CuCallInfo) = get_effects(gci.ci)
174174

175+
struct CthulhuCallInfo <: CCCallInfo
176+
meta::CallMeta
177+
end
178+
CC.add_edges_impl(edges::Vector{Any}, info::CthulhuCallInfo) = CC.add_edges!(edges, info.meta.info)
179+
CC.nsplit_impl(info::CthulhuCallInfo) = CC.nsplit(info.meta.info)
180+
CC.getsplit_impl(info::CthulhuCallInfo, idx::Int) = CC.getsplit(info.meta.info, idx)
181+
CC.getresult_impl(info::CthulhuCallInfo, idx::Int) = CC.getresult(info.meta.info, idx)
182+
175183
struct Callsite
176184
id::Int # ssa-id
177185
info::CallInfo

src/codeview.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ function add_callsites!(d::AbstractDict, visited_cis::AbstractSet, diagnostics::
344344
# e.g. if f(x) = x is called with different types we print nothing.
345345
key = (mi.def.file, mi.def.line)
346346
if haskey(d, key)
347-
if !isnothing(d[key]) && mi != d[key].mi
347+
if !isnothing(d[key]) && mi != d[key].ci.def
348348
d[key] = nothing
349349
push!(diagnostics,
350350
TypedSyntax.Diagnostic(

src/interpreter.jl

Lines changed: 80 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ end
1919
const InferenceKey = Union{CodeInstance,InferenceResult} # TODO make this `CodeInstance` fully
2020
const InferenceDict{InferenceValue} = IdDict{InferenceKey, InferenceValue}
2121
const PC2Remarks = Vector{Pair{Int, String}}
22+
const PC2CallMeta = Dict{Int, CallMeta}
2223
const PC2Effects = Dict{Int, Effects}
2324
const PC2Excts = Dict{Int, Any}
2425

@@ -29,6 +30,7 @@ struct CthulhuInterpreter <: AbstractInterpreter
2930
native::AbstractInterpreter
3031
unopt::InferenceDict{InferredSource}
3132
remarks::InferenceDict{PC2Remarks}
33+
calls::InferenceDict{PC2CallMeta}
3234
effects::InferenceDict{PC2Effects}
3335
exception_types::InferenceDict{PC2Excts}
3436
end
@@ -39,6 +41,7 @@ function CthulhuInterpreter(interp::AbstractInterpreter=NativeInterpreter())
3941
interp,
4042
InferenceDict{InferredSource}(),
4143
InferenceDict{PC2Remarks}(),
44+
InferenceDict{PC2CallMeta}(),
4245
InferenceDict{PC2Effects}(),
4346
InferenceDict{PC2Excts}())
4447
end
@@ -96,6 +99,21 @@ function CC.update_exc_bestguess!(interp::CthulhuInterpreter, @nospecialize(exct
9699
frame::InferenceState)
97100
end
98101

102+
function CC.abstract_call(interp::CthulhuInterpreter, arginfo::CC.ArgInfo, sstate::CC.StatementState, sv::InferenceState)
103+
call = @invoke CC.abstract_call(interp::AbstractInterpreter, arginfo::CC.ArgInfo, sstate::CC.StatementState, sv::InferenceState)
104+
if isa(sv, InferenceState)
105+
key = get_inference_key(sv)
106+
if key !== nothing
107+
CC.Future{Any}(call, interp, sv) do call, interp, sv
108+
calls = get!(PC2CallMeta, interp.calls, key)
109+
calls[sv.currpc] = call
110+
nothing
111+
end
112+
end
113+
end
114+
return call
115+
end
116+
99117
function InferredSource(state::InferenceState)
100118
unoptsrc = copy(state.src)
101119
exct = state.result.exc_result
@@ -107,6 +125,66 @@ function InferredSource(state::InferenceState)
107125
exct)
108126
end
109127

128+
@static if VERSION v"1.13-"
129+
function _finishinfer!(frame::InferenceState, interp::CthulhuInterpreter, cycleid::Int, opt_cache::IdDict{MethodInstance, CodeInstance})
130+
return @invoke CC.finishinfer!(frame::InferenceState, interp::AbstractInterpreter, cycleid::Int, opt_cache::IdDict{MethodInstance, CodeInstance})
131+
end
132+
else
133+
function _finishinfer!(frame::InferenceState, interp::CthulhuInterpreter, cycleid::Int)
134+
return @invoke CC.finishinfer!(frame::InferenceState, interp::AbstractInterpreter, cycleid::Int)
135+
end
136+
end
137+
138+
function cthulhu_finish(result::Union{Nothing, InferenceResult}, frame::InferenceState, interp::CthulhuInterpreter)
139+
key = get_inference_key(frame)
140+
key === nothing && return result
141+
interp.unopt[key] = InferredSource(frame)
142+
143+
# Wrap `CallInfo`s with `CthulhuCallInfo`s post-inference.
144+
calls = get(interp.calls, key, nothing)
145+
isnothing(calls) && return result
146+
for (i, info) in enumerate(frame.stmt_info)
147+
info === NoCallInfo() && continue
148+
call = get(calls, i, nothing)
149+
call === nothing && continue
150+
if isa(info, CC.UnionSplitApplyCallInfo)
151+
# XXX: `UnionSplitApplyCallInfo` is specially handled in `CC.inline_apply!`,
152+
# so we can't shove it under a `CthulhuCallInfo`.
153+
frame.stmt_info[i] = pack_cthulhuinfo_in_unionsplit(call, info)
154+
else
155+
frame.stmt_info[i] = CthulhuCallInfo(call)
156+
end
157+
end
158+
159+
return result
160+
end
161+
162+
# Rebuild a `CC.UnionSplitApplyCallInfo` structure where inner `ApplyCallInfo`s wrap a `CthulhuCallInfo`.
163+
# Note that technically, `rt`/`exct`/`effects`/`refinements` are incorrect for each apply call as they
164+
# apply to the union split as a whole, not to individual branches. The idea is simply to preserve them.
165+
function pack_cthulhuinfo_in_unionsplit(call::CallMeta, info::CC.UnionSplitApplyCallInfo)
166+
infos = CC.ApplyCallInfo[]
167+
for apply in info.infos
168+
meta = CallMeta(call.rt, call.exct, call.effects, apply.call, call.refinements)
169+
push!(infos, CC.ApplyCallInfo(CthulhuCallInfo(meta), apply.arginfo))
170+
end
171+
return CC.UnionSplitApplyCallInfo(infos)
172+
end
173+
174+
# Build a `CthulhuCallInfo` structure wrapping `CC.UnionSplitApplyCallInfo`.
175+
function unpack_cthulhuinfo_from_unionsplit(info::CC.UnionSplitApplyCallInfo)
176+
isempty(info.infos) && return nothing
177+
apply = info.infos[1]
178+
isa(apply.call, CthulhuCallInfo) || return nothing
179+
(; rt, exct, effects, refinements) = apply.call.meta
180+
infos = CC.ApplyCallInfo[]
181+
for apply in info.infos
182+
push!(infos, CC.ApplyCallInfo(apply.call.meta.info, apply.arginfo))
183+
end
184+
call = CallMeta(rt, exct, effects, CC.UnionSplitApplyCallInfo(infos), refinements)
185+
return CthulhuCallInfo(call)
186+
end
187+
110188
function create_cthulhu_source(result::InferenceResult, effects::Effects)
111189
isa(result.src, OptimizationState) || return result.src
112190
opt = result.src
@@ -127,25 +205,9 @@ function set_cthulhu_source!(result::InferenceResult)
127205
end
128206

129207
@static if VERSION v"1.13-"
130-
CC.finishinfer!(state::InferenceState, interp::CthulhuInterpreter, cycleid::Int, opt_cache::IdDict{MethodInstance, CodeInstance}) = cthulhu_finish(CC.finishinfer!, state, interp, cycleid, opt_cache)
131-
function cthulhu_finish(@specialize(finishfunc), state::InferenceState, interp::CthulhuInterpreter, cycleid::Int, opt_cache::IdDict{MethodInstance, CodeInstance})
132-
res = @invoke finishfunc(state::InferenceState, interp::AbstractInterpreter, cycleid::Int, opt_cache::IdDict{MethodInstance, CodeInstance})
133-
key = get_inference_key(state)
134-
if key !== nothing
135-
interp.unopt[key] = InferredSource(state)
136-
end
137-
return res
138-
end
208+
CC.finishinfer!(state::InferenceState, interp::CthulhuInterpreter, cycleid::Int, opt_cache::IdDict{MethodInstance, CodeInstance}) = cthulhu_finish(_finishinfer!(state, interp, cycleid, opt_cache), state, interp)
139209
else
140-
function cthulhu_finish(@specialize(finishfunc), state::InferenceState, interp::CthulhuInterpreter, cycleid::Int)
141-
res = @invoke finishfunc(state::InferenceState, interp::AbstractInterpreter, cycleid::Int)
142-
key = get_inference_key(state)
143-
if key !== nothing
144-
interp.unopt[key] = InferredSource(state)
145-
end
146-
return res
147-
end
148-
CC.finishinfer!(state::InferenceState, interp::CthulhuInterpreter, cycleid::Int) = cthulhu_finish(CC.finishinfer!, state, interp, cycleid)
210+
CC.finishinfer!(state::InferenceState, interp::CthulhuInterpreter, cycleid::Int) = cthulhu_finish(_finishinfer!(state, interp, cycleid), state, interp)
149211
end
150212

151213
function CC.finish!(interp::CthulhuInterpreter, caller::InferenceState, validation_world::UInt, time_before::UInt64)

src/reflection.jl

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,18 @@ function find_callsites(interp::AbstractInterpreter, CI::Union{CodeInfo,IRCode},
3333
if stmt_infos !== nothing && is_call_expr(stmt, optimize)
3434
info = stmt_infos[id]
3535
if info !== nothing
36-
rt = ignorelimited(argextype(SSAValue(id), CI, sptypes, slottypes))
36+
if isa(info, CC.UnionSplitApplyCallInfo)
37+
info = something(unpack_cthulhuinfo_from_unionsplit(info), info)
38+
end
39+
if isa(info, CthulhuCallInfo)
40+
# We have a `CallMeta` available.
41+
(; info, rt, exct, effects) = info.meta
42+
@assert !isa(info, CthulhuCallInfo)
43+
else
44+
rt = ignorelimited(argextype(SSAValue(id), CI, sptypes, slottypes))
45+
exct = isnothing(pc2excts) ? nothing : get(pc2excts, id, nothing)
46+
effects = nothing
47+
end
3748
# in unoptimized IR, there may be `slot = rhs` expressions, which `argextype` doesn't handle
3849
# so extract rhs for such an case
3950
local args = stmt.args
@@ -50,8 +61,7 @@ function find_callsites(interp::AbstractInterpreter, CI::Union{CodeInfo,IRCode},
5061
t = argextype(args[i], CI, sptypes, slottypes)
5162
argtypes[i] = ignorelimited(t)
5263
end
53-
exct = isnothing(pc2excts) ? nothing : get(pc2excts, id, nothing)
54-
callinfos = process_info(interp, info, argtypes, rt, optimize, exct)
64+
callinfos = process_info(interp, info, argtypes, rt, optimize, exct, effects)
5565
isempty(callinfos) && continue
5666
callsite = let
5767
if length(callinfos) == 1
@@ -146,8 +156,8 @@ end
146156

147157
function process_info(interp::AbstractInterpreter, @nospecialize(info::CCCallInfo),
148158
argtypes::ArgTypes, @nospecialize(rt), optimize::Bool,
149-
@nospecialize(exct))
150-
process_recursive(@nospecialize(newinfo)) = process_info(interp, newinfo, argtypes, rt, optimize, exct)
159+
@nospecialize(exct), effects::Union{Effects, Nothing})
160+
process_recursive(@nospecialize(newinfo)) = process_info(interp, newinfo, argtypes, rt, optimize, exct, effects)
151161

152162
if isa(info, MethodResultPure)
153163
if isa(info.info, CC.ReturnTypeCallInfo)
@@ -162,7 +172,7 @@ function process_info(interp::AbstractInterpreter, @nospecialize(info::CCCallInf
162172
if edge === nothing
163173
RTCallInfo(unwrapconst(argtypes[1]), argtypes[2:end], rt, exct)
164174
else
165-
effects = get_effects(edge)
175+
effects = @something(effects, get_effects(edge))
166176
EdgeCallInfo(edge, rt, effects, exct)
167177
end
168178
end for edge in info.edges if edge !== nothing]
@@ -183,7 +193,7 @@ function process_info(interp::AbstractInterpreter, @nospecialize(info::CCCallInf
183193
elseif isa(info, CC.InvokeCallInfo)
184194
edge = info.edge
185195
if edge !== nothing
186-
effects = get_effects(edge)
196+
effects = @something(effects, get_effects(edge))
187197
thisinfo = EdgeCallInfo(edge, rt, effects)
188198
innerinfo = process_const_info(interp, thisinfo, argtypes, rt, info.result, optimize, exct)
189199
else
@@ -194,7 +204,7 @@ function process_info(interp::AbstractInterpreter, @nospecialize(info::CCCallInf
194204
elseif isa(info, CC.OpaqueClosureCallInfo)
195205
edge = info.edge
196206
if edge !== nothing
197-
effects = get_effects(edge)
207+
effects = @something(effects, get_effects(edge))
198208
thisinfo = EdgeCallInfo(edge, rt, effects)
199209
innerinfo = process_const_info(interp, thisinfo, argtypes, rt, info.result, optimize, exct)
200210
else
@@ -212,7 +222,7 @@ function process_info(interp::AbstractInterpreter, @nospecialize(info::CCCallInf
212222
return CallInfo[]
213223
elseif isa(info, CC.ReturnTypeCallInfo)
214224
newargtypes = argtypes[2:end]
215-
callinfos = process_info(interp, info.info, newargtypes, unwrapType(widenconst(rt)), optimize, exct)
225+
callinfos = process_info(interp, info.info, newargtypes, unwrapType(widenconst(rt)), optimize, exct, effects)
216226
if length(callinfos) == 1
217227
vmi = only(callinfos)
218228
else

0 commit comments

Comments
 (0)