Skip to content

Commit af9e6e3

Browse files
authored
optimize abstract_call_gf_by_type (JuliaLang#56572)
Combines many allocations into one and types them for better inference
1 parent c5899c2 commit af9e6e3

File tree

1 file changed

+120
-93
lines changed

1 file changed

+120
-93
lines changed

Compiler/src/abstractinterpretation.jl

Lines changed: 120 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,76 @@ function propagate_conditional(rt::InterConditional, cond::Conditional)
3838
return Conditional(cond.slot, new_thentype, new_elsetype)
3939
end
4040

41-
function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
41+
mutable struct SafeBox{T}
42+
x::T
43+
SafeBox{T}(x::T) where T = new{T}(x)
44+
SafeBox(@nospecialize x) = new{Any}(x)
45+
end
46+
getindex(box::SafeBox) = box.x
47+
setindex!(box::SafeBox{T}, x::T) where T = setfield!(box, :x, x)
48+
49+
struct FailedMethodMatch
50+
reason::String
51+
end
52+
53+
struct MethodMatchTarget
54+
match::MethodMatch
55+
edges::Vector{Union{Nothing,CodeInstance}}
56+
edge_idx::Int
57+
end
58+
59+
struct MethodMatches
60+
applicable::Vector{MethodMatchTarget}
61+
info::MethodMatchInfo
62+
valid_worlds::WorldRange
63+
end
64+
any_ambig(result::MethodLookupResult) = result.ambig
65+
any_ambig(info::MethodMatchInfo) = any_ambig(info.results)
66+
any_ambig(m::MethodMatches) = any_ambig(m.info)
67+
fully_covering(info::MethodMatchInfo) = info.fullmatch
68+
fully_covering(m::MethodMatches) = fully_covering(m.info)
69+
70+
struct UnionSplitMethodMatches
71+
applicable::Vector{MethodMatchTarget}
72+
applicable_argtypes::Vector{Vector{Any}}
73+
info::UnionSplitInfo
74+
valid_worlds::WorldRange
75+
end
76+
any_ambig(info::UnionSplitInfo) = any(any_ambig, info.split)
77+
any_ambig(m::UnionSplitMethodMatches) = any_ambig(m.info)
78+
fully_covering(info::UnionSplitInfo) = all(fully_covering, info.split)
79+
fully_covering(m::UnionSplitMethodMatches) = fully_covering(m.info)
80+
81+
nmatches(info::MethodMatchInfo) = length(info.results)
82+
function nmatches(info::UnionSplitInfo)
83+
n = 0
84+
for mminfo in info.split
85+
n += nmatches(mminfo)
86+
end
87+
return n
88+
end
89+
90+
# intermediate state for computing gfresult
91+
mutable struct CallInferenceState
92+
inferidx::Int
93+
rettype
94+
exctype
95+
all_effects::Effects
96+
const_results::Union{Nothing,Vector{Union{Nothing,ConstResult}}} # keeps the results of inference with the extended lattice elements (if happened)
97+
conditionals::Union{Nothing,Tuple{Vector{Any},Vector{Any}}} # keeps refinement information of call argument types when the return type is boolean
98+
slotrefinements::Union{Nothing,Vector{Any}} # keeps refinement information on slot types obtained from call signature
99+
100+
# some additional fields for untyped objects (just to avoid capturing)
101+
func
102+
matches::Union{MethodMatches,UnionSplitMethodMatches}
103+
function CallInferenceState(@nospecialize(func), matches::Union{MethodMatches,UnionSplitMethodMatches})
104+
return new(#=inferidx=#1, #=rettype=#Bottom, #=exctype=#Bottom, #=all_effects=#EFFECTS_TOTAL,
105+
#=const_results=#nothing, #=conditionals=#nothing, #=slotrefinements=#nothing,
106+
func, matches)
107+
end
108+
end
109+
110+
function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(func),
42111
arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype),
43112
sv::AbsIntState, max_methods::Int)
44113
𝕃ₚ, 𝕃ᵢ = ipo_lattice(interp), typeinf_lattice(interp)
@@ -50,12 +119,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
50119
return Future(CallMeta(Any, Any, Effects(), NoCallInfo()))
51120
end
52121

53-
(; valid_worlds, applicable, info) = matches
122+
(; valid_worlds, applicable) = matches
54123
update_valid_age!(sv, valid_worlds) # need to record the negative world now, since even if we don't generate any useful information, inlining might want to add an invoke edge and it won't have this information anymore
55124
if bail_out_toplevel_call(interp, sv)
56-
napplicable = length(applicable)
125+
local napplicable = length(applicable)
57126
for i = 1:napplicable
58-
sig = applicable[i].match.spec_types
127+
local sig = applicable[i].match.spec_types
59128
if !isdispatchtuple(sig)
60129
# only infer fully concrete call sites in top-level expressions (ignoring even isa_compileable_sig matches)
61130
add_remark!(interp, sv, "Refusing to infer non-concrete call site in top-level expression")
@@ -66,26 +135,17 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
66135

67136
# final result
68137
gfresult = Future{CallMeta}()
69-
# intermediate work for computing gfresult
70-
rettype = exctype = Bottom
71-
conditionals = nothing # keeps refinement information of call argument types when the return type is boolean
72-
const_results = nothing # or const_results::Vector{Union{Nothing,ConstResult}} if any const results are available
73-
fargs = arginfo.fargs
74-
all_effects = EFFECTS_TOTAL
75-
slotrefinements = nothing # keeps refinement information on slot types obtained from call signature
138+
state = CallInferenceState(func, matches)
76139

77140
# split the for loop off into a function, so that we can pause and restart it at will
78-
i::Int = 1
79-
f = Core.Box(f)
80-
atype = Core.Box(atype)
81141
function infercalls(interp, sv)
82142
local napplicable = length(applicable)
83143
local multiple_matches = napplicable > 1
84-
while i <= napplicable
85-
(; match, edges, edge_idx) = applicable[i]
86-
method = match.method
87-
sig = match.spec_types
88-
if bail_out_call(interp, InferenceLoopState(rettype, all_effects), sv)
144+
while state.inferidx <= napplicable
145+
(; match, edges, edge_idx) = applicable[state.inferidx]
146+
local method = match.method
147+
local sig = match.spec_types
148+
if bail_out_call(interp, InferenceLoopState(state.rettype, state.all_effects), sv)
89149
add_remark!(interp, sv, "Call inference reached maximally imprecise information: bailing on doing more abstract inference.")
90150
break
91151
end
@@ -108,10 +168,11 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
108168
this_exct = exct
109169
# try constant propagation with argtypes for this match
110170
# this is in preparation for inlining, or improving the return result
111-
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
112-
this_arginfo = ArgInfo(fargs, this_argtypes)
171+
local matches = state.matches
172+
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[state.inferidx]
173+
this_arginfo = ArgInfo(arginfo.fargs, this_argtypes)
113174
const_call_result = abstract_call_method_with_const_args(interp,
114-
mresult[], f.contents, this_arginfo, si, match, sv)
175+
mresult[], state.func, this_arginfo, si, match, sv)
115176
const_result = volatile_inf_result
116177
if const_call_result !== nothing
117178
this_const_conditional = ignorelimited(const_call_result.rt)
@@ -146,12 +207,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
146207
end
147208
end
148209

149-
all_effects = merge_effects(all_effects, effects)
210+
state.all_effects = merge_effects(state.all_effects, effects)
150211
if const_result !== nothing
212+
local const_results = state.const_results
151213
if const_results === nothing
152-
const_results = fill!(Vector{Union{Nothing,ConstResult}}(undef, napplicable), nothing)
214+
const_results = state.const_results = fill!(Vector{Union{Nothing,ConstResult}}(undef, napplicable), nothing)
153215
end
154-
const_results[i] = const_result
216+
const_results[state.inferidx] = const_result
155217
end
156218
@assert !(this_conditional isa Conditional || this_rt isa MustAlias) "invalid lattice element returned from inter-procedural context"
157219
if can_propagate_conditional(this_conditional, argtypes)
@@ -162,12 +224,14 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
162224
this_rt = this_conditional
163225
end
164226

165-
rettype = rettype ₚ this_rt
166-
exctype = exctype ₚ this_exct
167-
if has_conditional(𝕃ₚ, sv) && this_conditional !== Bottom && is_lattice_bool(𝕃ₚ, rettype) && fargs !== nothing
227+
state.rettype = state.rettype ₚ this_rt
228+
state.exctype = state.exctype ₚ this_exct
229+
if has_conditional(𝕃ₚ, sv) && this_conditional !== Bottom && is_lattice_bool(𝕃ₚ, state.rettype) && arginfo.fargs !== nothing
230+
local conditionals = state.conditionals
168231
if conditionals === nothing
169-
conditionals = Any[Bottom for _ in 1:length(argtypes)],
170-
Any[Bottom for _ in 1:length(argtypes)]
232+
conditionals = state.conditionals = (
233+
Any[Bottom for _ in 1:length(argtypes)],
234+
Any[Bottom for _ in 1:length(argtypes)])
171235
end
172236
for i = 1:length(argtypes)
173237
cnd = conditional_argtype(𝕃ᵢ, this_conditional, match.spec_types, argtypes, i)
@@ -177,7 +241,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
177241
end
178242
edges[edge_idx] = edge
179243

180-
i += 1
244+
state.inferidx += 1
181245
return true
182246
end # function handle1
183247
if isready(mresult) && handle1(interp, sv)
@@ -188,30 +252,33 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
188252
end
189253
end # while
190254

191-
seenall = i > napplicable
255+
seenall = state.inferidx > napplicable
256+
retinfo = state.matches.info
192257
if seenall # small optimization to skip some work that is already implied
258+
local const_results = state.const_results
193259
if const_results !== nothing
194-
@assert napplicable == nmatches(info) == length(const_results)
195-
info = ConstCallInfo(info, const_results)
260+
@assert napplicable == nmatches(retinfo) == length(const_results)
261+
retinfo = ConstCallInfo(retinfo, const_results)
196262
end
197-
if !fully_covering(matches) || any_ambig(matches)
263+
if !fully_covering(state.matches) || any_ambig(state.matches)
198264
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
199-
all_effects = Effects(all_effects; nothrow=false)
200-
exctype = exctype ₚ MethodError
265+
state.all_effects = Effects(state.all_effects; nothrow=false)
266+
state.exctype = state.exctype ₚ MethodError
201267
end
268+
local fargs = arginfo.fargs
202269
if sv isa InferenceState && fargs !== nothing
203-
slotrefinements = collect_slot_refinements(𝕃ᵢ, applicable, argtypes, fargs, sv)
270+
state.slotrefinements = collect_slot_refinements(𝕃ᵢ, applicable, argtypes, fargs, sv)
204271
end
205-
rettype = from_interprocedural!(interp, rettype, sv, arginfo, conditionals)
206-
if call_result_unused(si) && !(rettype === Bottom)
272+
state.rettype = from_interprocedural!(interp, state.rettype, sv, arginfo, state.conditionals)
273+
if call_result_unused(si) && !(state.rettype === Bottom)
207274
add_remark!(interp, sv, "Call result type was widened because the return value is unused")
208275
# We're mainly only here because the optimizer might want this code,
209276
# but we ourselves locally don't typically care about it locally
210277
# (beyond checking if it always throws).
211278
# So avoid adding an edge, since we don't want to bother attempting
212279
# to improve our result even if it does change (to always throw),
213280
# and avoid keeping track of a more complex result type.
214-
rettype = Any
281+
state.rettype = Any
215282
end
216283
# if from_interprocedural added any pclimitations to the set inherited from the arguments,
217284
# some of those may be part of our cycles, so those can be deleted now
@@ -230,23 +297,24 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
230297
end
231298
else
232299
# there is unanalyzed candidate, widen type and effects to the top
233-
rettype = exctype = Any
234-
all_effects = Effects()
235-
const_results = nothing
300+
state.rettype = state.exctype = Any
301+
state.all_effects = Effects()
302+
state.const_results = nothing
236303
end
237304

238305
# Also considering inferring the compilation signature for this method, so
239306
# it is available to the compiler in case it ends up needing it for the invoke.
240-
if isa(sv, InferenceState) && infer_compilation_signature(interp) && (!is_removable_if_unused(all_effects) || !call_result_unused(si))
241-
i = 1
307+
if (isa(sv, InferenceState) && infer_compilation_signature(interp) &&
308+
(!is_removable_if_unused(state.all_effects) || !call_result_unused(si)))
309+
inferidx = SafeBox{Int}(1)
242310
function infercalls2(interp, sv)
243311
local napplicable = length(applicable)
244312
local multiple_matches = napplicable > 1
245-
while i <= napplicable
246-
(; match, edges, edge_idx) = applicable[i]
247-
i += 1
248-
method = match.method
249-
sig = match.spec_types
313+
while inferidx[] <= napplicable
314+
(; match, edges, edge_idx) = applicable[inferidx[]]
315+
inferidx[] += 1
316+
local method = match.method
317+
local sig = match.spec_types
250318
mi = specialize_method(match; preexisting=true)
251319
if mi === nothing || !const_prop_methodinstance_heuristic(interp, mi, arginfo, sv)
252320
csig = get_compileable_sig(method, sig, match.sparams)
@@ -265,55 +333,14 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
265333
infercalls2(interp, sv) || push!(sv.tasks, infercalls2)
266334
end
267335

268-
gfresult[] = CallMeta(rettype, exctype, all_effects, info, slotrefinements)
336+
gfresult[] = CallMeta(state.rettype, state.exctype, state.all_effects, retinfo, state.slotrefinements)
269337
return true
270338
end # function infercalls
271339
# start making progress on the first call
272340
infercalls(interp, sv) || push!(sv.tasks, infercalls)
273341
return gfresult
274342
end
275343

276-
struct FailedMethodMatch
277-
reason::String
278-
end
279-
280-
struct MethodMatchTarget
281-
match::MethodMatch
282-
edges::Vector{Union{Nothing,CodeInstance}}
283-
edge_idx::Int
284-
end
285-
286-
struct MethodMatches
287-
applicable::Vector{MethodMatchTarget}
288-
info::MethodMatchInfo
289-
valid_worlds::WorldRange
290-
end
291-
any_ambig(result::MethodLookupResult) = result.ambig
292-
any_ambig(info::MethodMatchInfo) = any_ambig(info.results)
293-
any_ambig(m::MethodMatches) = any_ambig(m.info)
294-
fully_covering(info::MethodMatchInfo) = info.fullmatch
295-
fully_covering(m::MethodMatches) = fully_covering(m.info)
296-
297-
struct UnionSplitMethodMatches
298-
applicable::Vector{MethodMatchTarget}
299-
applicable_argtypes::Vector{Vector{Any}}
300-
info::UnionSplitInfo
301-
valid_worlds::WorldRange
302-
end
303-
any_ambig(info::UnionSplitInfo) = any(any_ambig, info.split)
304-
any_ambig(m::UnionSplitMethodMatches) = any_ambig(m.info)
305-
fully_covering(info::UnionSplitInfo) = all(fully_covering, info.split)
306-
fully_covering(m::UnionSplitMethodMatches) = fully_covering(m.info)
307-
308-
nmatches(info::MethodMatchInfo) = length(info.results)
309-
function nmatches(info::UnionSplitInfo)
310-
n = 0
311-
for mminfo in info.split
312-
n += nmatches(mminfo)
313-
end
314-
return n
315-
end
316-
317344
function find_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any}, @nospecialize(atype);
318345
max_union_splitting::Int = InferenceParams(interp).max_union_splitting,
319346
max_methods::Int = InferenceParams(interp).max_methods)

0 commit comments

Comments
 (0)