Skip to content

Commit 1232010

Browse files
committed
compiler: general refactor (#41633)
Separated from compiler-plugin prototyping. cherry-picked from 799136d
1 parent c85012a commit 1232010

File tree

7 files changed

+149
-119
lines changed

7 files changed

+149
-119
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 110 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -35,73 +35,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
3535
add_remark!(interp, sv, "Skipped call in throw block")
3636
return CallMeta(Any, false)
3737
end
38-
valid_worlds = WorldRange()
39-
# NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type
40-
splitunions = 1 < unionsplitcost(argtypes) <= InferenceParams(interp).MAX_UNION_SPLITTING
41-
mts = Core.MethodTable[]
42-
fullmatch = Bool[]
43-
if splitunions
44-
split_argtypes = switchtupleunion(argtypes)
45-
applicable = Any[]
46-
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
47-
infos = MethodMatchInfo[]
48-
for arg_n in split_argtypes
49-
sig_n = argtypes_to_type(arg_n)
50-
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
51-
if mt === nothing
52-
add_remark!(interp, sv, "Could not identify method table for call")
53-
return CallMeta(Any, false)
54-
end
55-
mt = mt::Core.MethodTable
56-
matches = findall(sig_n, method_table(interp); limit=max_methods)
57-
if matches === missing
58-
add_remark!(interp, sv, "For one of the union split cases, too many methods matched")
59-
return CallMeta(Any, false)
60-
end
61-
push!(infos, MethodMatchInfo(matches))
62-
for m in matches
63-
push!(applicable, m)
64-
push!(applicable_argtypes, arg_n)
65-
end
66-
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
67-
thisfullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
68-
found = false
69-
for (i, mt′) in enumerate(mts)
70-
if mt′ === mt
71-
fullmatch[i] &= thisfullmatch
72-
found = true
73-
break
74-
end
75-
end
76-
if !found
77-
push!(mts, mt)
78-
push!(fullmatch, thisfullmatch)
79-
end
80-
end
81-
info = UnionSplitInfo(infos)
82-
else
83-
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
84-
if mt === nothing
85-
add_remark!(interp, sv, "Could not identify method table for call")
86-
return CallMeta(Any, false)
87-
end
88-
mt = mt::Core.MethodTable
89-
matches = findall(atype, method_table(interp, sv); limit=max_methods)
90-
if matches === missing
91-
# this means too many methods matched
92-
# (assume this will always be true, so we don't compute / update valid age in this case)
93-
add_remark!(interp, sv, "Too many methods matched")
94-
return CallMeta(Any, false)
95-
end
96-
push!(mts, mt)
97-
push!(fullmatch, _any(match->(match::MethodMatch).fully_covers, matches))
98-
info = MethodMatchInfo(matches)
99-
applicable = matches.matches
100-
valid_worlds = matches.valid_worlds
101-
applicable_argtypes = nothing
38+
39+
matches = find_matching_methods(argtypes, atype, method_table(interp, sv), InferenceParams(interp).MAX_UNION_SPLITTING, max_methods)
40+
if isa(matches, FailedMethodMatch)
41+
add_remark!(interp, sv, matches.reason)
42+
return CallMeta(Any, false)
10243
end
44+
45+
(; valid_worlds, applicable, info) = matches
10346
update_valid_age!(sv, valid_worlds)
104-
applicable = applicable::Array{Any,1}
10547
napplicable = length(applicable)
10648
rettype = Bottom
10749
edges = MethodInstance[]
@@ -142,7 +84,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
14284
if edge !== nothing
14385
push!(edges, edge)
14486
end
145-
this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i]
87+
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
14688
const_rt, const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
14789
if const_rt !== rt && const_rt rt
14890
rt = const_rt
@@ -164,7 +106,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
164106
end
165107
# try constant propagation with argtypes for this match
166108
# this is in preparation for inlining, or improving the return result
167-
this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i]
109+
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
168110
const_this_rt, const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
169111
if const_this_rt !== this_rt && const_this_rt this_rt
170112
this_rt = const_this_rt
@@ -272,7 +214,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
272214
# and avoid keeping track of a more complex result type.
273215
rettype = Any
274216
end
275-
add_call_backedges!(interp, rettype, edges, fullmatch, mts, atype, sv)
217+
add_call_backedges!(interp, rettype, edges, matches, atype, sv)
276218
if !isempty(sv.pclimitations) # remove self, if present
277219
delete!(sv.pclimitations, sv)
278220
for caller in sv.callers_in_cycle
@@ -283,24 +225,110 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
283225
return CallMeta(rettype, info)
284226
end
285227

286-
function add_call_backedges!(interp::AbstractInterpreter,
287-
@nospecialize(rettype),
288-
edges::Vector{MethodInstance},
289-
fullmatch::Vector{Bool}, mts::Vector{Core.MethodTable}, @nospecialize(atype),
290-
sv::InferenceState)
291-
if rettype === Any
292-
# for `NativeInterpreter`, we don't add backedges when a new method couldn't refine
293-
# (widen) this type
294-
return
228+
struct FailedMethodMatch
229+
reason::String
230+
end
231+
232+
struct MethodMatches
233+
applicable::Vector{Any}
234+
info::MethodMatchInfo
235+
valid_worlds::WorldRange
236+
mt::Core.MethodTable
237+
fullmatch::Bool
238+
end
239+
240+
struct UnionSplitMethodMatches
241+
applicable::Vector{Any}
242+
applicable_argtypes::Vector{Vector{Any}}
243+
info::UnionSplitInfo
244+
valid_worlds::WorldRange
245+
mts::Vector{Core.MethodTable}
246+
fullmatches::Vector{Bool}
247+
end
248+
249+
function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView,
250+
union_split::Int, max_methods::Int)
251+
# NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type
252+
if 1 < unionsplitcost(argtypes) <= union_split
253+
split_argtypes = switchtupleunion(argtypes)
254+
infos = MethodMatchInfo[]
255+
applicable = Any[]
256+
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
257+
valid_worlds = WorldRange()
258+
mts = Core.MethodTable[]
259+
fullmatches = Bool[]
260+
for i in 1:length(split_argtypes)
261+
arg_n = split_argtypes[i]::Vector{Any}
262+
sig_n = argtypes_to_type(arg_n)
263+
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
264+
mt === nothing && return FailedMethodMatch("Could not identify method table for call")
265+
mt = mt::Core.MethodTable
266+
matches = findall(sig_n, method_table; limit = max_methods)
267+
if matches === missing
268+
return FailedMethodMatch("For one of the union split cases, too many methods matched")
269+
end
270+
push!(infos, MethodMatchInfo(matches))
271+
for m in matches
272+
push!(applicable, m)
273+
push!(applicable_argtypes, arg_n)
274+
end
275+
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
276+
thisfullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
277+
found = false
278+
for (i, mt′) in enumerate(mts)
279+
if mt′ === mt
280+
fullmatches[i] &= thisfullmatch
281+
found = true
282+
break
283+
end
284+
end
285+
if !found
286+
push!(mts, mt)
287+
push!(fullmatches, thisfullmatch)
288+
end
289+
end
290+
return UnionSplitMethodMatches(applicable,
291+
applicable_argtypes,
292+
UnionSplitInfo(infos),
293+
valid_worlds,
294+
mts,
295+
fullmatches)
296+
else
297+
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
298+
if mt === nothing
299+
return FailedMethodMatch("Could not identify method table for call")
300+
end
301+
mt = mt::Core.MethodTable
302+
matches = findall(atype, method_table; limit = max_methods)
303+
if matches === missing
304+
# this means too many methods matched
305+
# (assume this will always be true, so we don't compute / update valid age in this case)
306+
return FailedMethodMatch("Too many methods matched")
307+
end
308+
fullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
309+
return MethodMatches(matches.matches,
310+
MethodMatchInfo(matches),
311+
matches.valid_worlds,
312+
mt,
313+
fullmatch)
295314
end
315+
end
316+
317+
function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype), edges::Vector{MethodInstance},
318+
matches::Union{MethodMatches,UnionSplitMethodMatches}, @nospecialize(atype),
319+
sv::InferenceState)
320+
# for `NativeInterpreter`, we don't add backedges when a new method couldn't refine (widen) this type
321+
rettype === Any && return
296322
for edge in edges
297323
add_backedge!(edge, sv)
298324
end
299-
for (thisfullmatch, mt) in zip(fullmatch, mts)
300-
if !thisfullmatch
301-
# also need an edge to the method table in case something gets
302-
# added that did not intersect with any existing method
303-
add_mt_backedge!(mt, atype, sv)
325+
# also need an edge to the method table in case something gets
326+
# added that did not intersect with any existing method
327+
if isa(matches, MethodMatches)
328+
matches.fullmatch || add_mt_backedge!(matches.mt, atype, sv)
329+
else
330+
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
331+
thisfullmatch || add_mt_backedge!(mt, atype, sv)
304332
end
305333
end
306334
end

base/compiler/inferenceresult.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ end
1313
# for the provided `linfo` and `given_argtypes`. The purpose of this function is
1414
# to return a valid value for `cache_lookup(linfo, argtypes, cache).argtypes`,
1515
# so that we can construct cache-correct `InferenceResult`s in the first place.
16-
function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector, va_override)
16+
function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector, va_override::Bool)
1717
@assert isa(linfo.def, Method) # ensure the next line works
1818
nargs::Int = linfo.def.nargs
1919
@assert length(given_argtypes) >= (nargs - 1)

base/compiler/optimize.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,11 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
196196
return true
197197
end
198198

199-
# Convert IRCode back to CodeInfo and compute inlining cost and sideeffects
199+
# compute inlining cost and sideeffects
200200
function finish(interp::AbstractInterpreter, opt::OptimizationState, params::OptimizationParams, ir::IRCode, @nospecialize(result))
201-
def = opt.linfo.def
202-
nargs = Int(opt.nargs) - 1
201+
(; src, nargs, linfo) = opt
202+
(; def, specTypes) = linfo
203+
nargs = Int(nargs) - 1
203204

204205
force_noinline = _any(@nospecialize(x) -> isexpr(x, :meta) && x.args[1] === :noinline, ir.meta)
205206

@@ -221,7 +222,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
221222
end
222223
end
223224
if proven_pure
224-
for fl in opt.src.slotflags
225+
for fl in src.slotflags
225226
if (fl & SLOT_USEDUNDEF) != 0
226227
proven_pure = false
227228
break
@@ -230,7 +231,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
230231
end
231232
end
232233
if proven_pure
233-
opt.src.pure = true
234+
src.pure = true
234235
end
235236

236237
if proven_pure
@@ -243,7 +244,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
243244
if !(isa(result, Const) && !is_inlineable_constant(result.val))
244245
opt.const_api = true
245246
end
246-
force_noinline || (opt.src.inlineable = true)
247+
force_noinline || (src.inlineable = true)
247248
end
248249
end
249250

@@ -252,7 +253,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
252253
# determine and cache inlineability
253254
union_penalties = false
254255
if !force_noinline
255-
sig = unwrap_unionall(opt.linfo.specTypes)
256+
sig = unwrap_unionall(specTypes)
256257
if isa(sig, DataType) && sig.name === Tuple.name
257258
for P in sig.parameters
258259
P = unwrap_unionall(P)
@@ -264,25 +265,25 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
264265
else
265266
force_noinline = true
266267
end
267-
if !opt.src.inlineable && result === Union{}
268+
if !src.inlineable && result === Union{}
268269
force_noinline = true
269270
end
270271
end
271272
if force_noinline
272-
opt.src.inlineable = false
273+
src.inlineable = false
273274
elseif isa(def, Method)
274-
if opt.src.inlineable && isdispatchtuple(opt.linfo.specTypes)
275+
if src.inlineable && isdispatchtuple(specTypes)
275276
# obey @inline declaration if a dispatch barrier would not help
276277
else
277278
bonus = 0
278279
if result Tuple && !isconcretetype(widenconst(result))
279280
bonus = params.inline_tupleret_bonus
280281
end
281-
if opt.src.inlineable
282+
if src.inlineable
282283
# For functions declared @inline, increase the cost threshold 20x
283284
bonus += params.inline_cost_threshold*19
284285
end
285-
opt.src.inlineable = isinlineable(def, opt, params, union_penalties, bonus)
286+
src.inlineable = isinlineable(def, opt, params, union_penalties, bonus)
286287
end
287288
end
288289

0 commit comments

Comments
 (0)