@@ -6,18 +6,20 @@ using Base.Meta
66using Base: may_invoke_generator
77
88transform(:: Val , callsite) = callsite
9- function transform(:: Val{:CuFunction} , callsite, callexpr, CI, mi, slottypes; world= get_world_counter())
9+ function transform(:: Val{:CuFunction} , interp, callsite, callexpr, CI, mi, slottypes; world= get_world_counter())
1010 sptypes = sptypes_from_meth_instance(mi)
1111 tt = argextype(callexpr. args[4 ], CI, sptypes, slottypes)
1212 ft = argextype(callexpr. args[3 ], CI, sptypes, slottypes)
1313 isa(tt, Const) || return callsite
14- return Callsite(callsite. id, CuCallInfo(callinfo(Tuple{widenconst(ft), tt. val. parameters... }, Nothing; world)), callsite. head)
14+ sig = Tuple{widenconst(ft), tt. val. parameters... }
15+ return Callsite(callsite. id, CuCallInfo(callinfo(interp, sig, Nothing; world)), callsite. head)
1516end
1617
1718function find_callsites(interp:: AbstractInterpreter , CI:: Union{CodeInfo,IRCode} ,
18- stmt_infos:: Union{Vector{CCCallInfo}, Nothing} , mi :: MethodInstance ,
19+ stmt_infos:: Union{Vector{CCCallInfo}, Nothing} , ci :: CodeInstance ,
1920 slottypes:: Vector{Any} , optimize:: Bool = true , annotate_source:: Bool = false ,
2021 pc2excts:: Union{Nothing,PC2Excts} = nothing )
22+ mi = ci. def
2123 sptypes = sptypes_from_meth_instance(mi)
2224 callsites, sourcenodes = Callsite[], Union{TypedSyntax. MaybeTypedSyntaxNode,Callsite}[]
2325 if isa(CI, IRCode)
@@ -89,7 +91,7 @@ function find_callsites(interp::AbstractInterpreter, CI::Union{CodeInfo,IRCode},
8991 func = args[6 ]
9092 ftype = widenconst(argextype(func, CI, sptypes, slottypes))
9193 sig = Tuple{ftype}
92- callsite = Callsite(id, TaskCallInfo(callinfo(sig, nothing ; world= get_inference_world(interp))), head)
94+ callsite = Callsite(id, TaskCallInfo(callinfo(interp, sig, nothing ; world= get_inference_world(interp))), head)
9395 end
9496 end
9597 end
@@ -101,7 +103,7 @@ function find_callsites(interp::AbstractInterpreter, CI::Union{CodeInfo,IRCode},
101103 ci = get_ci(info)
102104 meth = ci. def. def
103105 if isa(meth, Method) && nameof(meth. module) === :CUDAnative && meth. name === :cufunction
104- callsite = transform(Val(:CuFunction), callsite, c, CI, ci. def, slottypes; world= get_inference_world(interp))
106+ callsite = transform(Val(:CuFunction), interp, callsite, c, CI, ci. def, slottypes; world= get_inference_world(interp))
105107 end
106108 end
107109
@@ -281,7 +283,7 @@ function preprocess_ci!(ir::IRCode, _::MethodInstance, optimize::Bool, config::C
281283 return ir
282284end
283285
284- function callinfo(sig, rt, max_methods= - 1 ; world= get_world_counter())
286+ function callinfo(interp, sig, rt, max_methods= - 1 ; world= get_world_counter())
285287 methds = Base. _methods_by_ftype(sig, max_methods, world)
286288 methds isa Bool && return FailedCallInfo(sig, rt)
287289 length(methds) < 1 && return FailedCallInfo(sig, rt)
@@ -295,7 +297,8 @@ function callinfo(sig, rt, max_methods=-1; world=get_world_counter())
295297 else
296298 mi = specialize_method(meth, atypes, sparams)
297299 if mi != = nothing
298- push!(callinfos, EdgeCallInfo(mi, rt, Effects()))
300+ edge = do_typeinf!(interp, mi)
301+ push!(callinfos, EdgeCallInfo(edge, rt, Effects()))
299302 else
300303 push!(callinfos, FailedCallInfo(sig, rt))
301304 end
@@ -315,7 +318,7 @@ function find_caller_of(interp::AbstractInterpreter, callee::Union{MethodInstanc
315318 for optimize in (true , false )
316319 (; src, rt, infos, slottypes) = lookup(interp′, codeinst, optimize)
317320 src = preprocess_ci!(src, caller, optimize, CONFIG)
318- callsites, _ = find_callsites(interp′, src, infos, caller , slottypes, optimize)
321+ callsites, _ = find_callsites(interp′, src, infos, codeinst , slottypes, optimize)
319322 callsites = allow_unspecialized ? filter(cs-> maybe_callsite(cs, callee), callsites) :
320323 filter(cs-> is_callsite(cs, callee), callsites)
321324 foreach(cs -> add_sourceline!(locs, src, cs. id, caller), callsites)
0 commit comments