@@ -675,24 +675,17 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
675675 new_stmt = Expr (:call , argexprs[2 ], def, state... )
676676 state1 = insert_node! (ir, idx, NewInstruction (new_stmt, call. rt))
677677 new_sig = with_atype (call_sig (ir, new_stmt):: Signature )
678- info = call. info
679- handled = false
680- if isa (info, ConstCallInfo)
681- if maybe_handle_const_call! (
682- ir, state1. id, new_stmt, info, new_sig,
683- istate, false , todo)
684- handled = true
685- else
686- info = info. call
687- end
688- end
689- if ! handled && (isa (info, MethodMatchInfo) || isa (info, UnionSplitInfo))
690- info = isa (info, MethodMatchInfo) ?
691- MethodMatchInfo[info] : info. matches
678+ new_info = call. info
679+ if isa (new_info, ConstCallInfo)
680+ handle_const_call! (
681+ ir, state1. id, new_stmt, new_info,
682+ new_sig, istate, todo)
683+ elseif isa (new_info, MethodMatchInfo) || isa (new_info, UnionSplitInfo)
684+ new_infos = isa (new_info, MethodMatchInfo) ? MethodMatchInfo[new_info] : new_info. matches
692685 # See if we can inline this call to `iterate`
693686 analyze_single_call! (
694687 ir, todo, state1. id, new_stmt,
695- new_sig, info , istate)
688+ new_sig, new_infos , istate)
696689 end
697690 if i != length (thisarginfo. each)
698691 valT = getfield_tfunc (call. rt, Const (1 ))
@@ -1200,49 +1193,38 @@ function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sta
12001193 return sig
12011194end
12021195
1203- # TODO inline non-`isdispatchtuple`, union-split callsites
1196+ # TODO inline non-`isdispatchtuple`, union-split callsites?
12041197function analyze_single_call! (
12051198 ir:: IRCode , todo:: Vector{Pair{Int, Any}} , idx:: Int , @nospecialize (stmt),
12061199 (; atypes, atype):: Signature , infos:: Vector{MethodMatchInfo} , state:: InliningState )
12071200 cases = InliningCase[]
12081201 local signature_union = Bottom
12091202 local only_method = nothing # keep track of whether there is one matching method
1210- local meth
1203+ local meth:: MethodLookupResult
12111204 local fully_covered = true
12121205 for i in 1 : length (infos)
1213- info = infos[i]
1214- meth = info. results
1206+ meth = infos[i]. results
12151207 if meth. ambig
12161208 # Too many applicable methods
12171209 # Or there is a (partial?) ambiguity
1218- return
1210+ return nothing
12191211 elseif length (meth) == 0
12201212 # No applicable methods; try next union split
12211213 continue
1222- elseif length (meth) == 1 && only_method != = false
1223- if only_method === nothing
1224- only_method = meth[1 ]. method
1225- elseif only_method != = meth[1 ]. method
1214+ else
1215+ if length (meth) == 1 && only_method != = false
1216+ if only_method === nothing
1217+ only_method = meth[1 ]. method
1218+ elseif only_method != = meth[1 ]. method
1219+ only_method = false
1220+ end
1221+ else
12261222 only_method = false
12271223 end
1228- else
1229- only_method = false
12301224 end
12311225 for match in meth
1232- spec_types = match. spec_types
1233- signature_union = Union{signature_union, spec_types}
1234- if ! isdispatchtuple (spec_types)
1235- fully_covered = false
1236- continue
1237- end
1238- item = analyze_method! (match, atypes, state)
1239- if item === nothing
1240- fully_covered = false
1241- continue
1242- elseif _any (case-> case. sig === spec_types, cases)
1243- continue
1244- end
1245- push! (cases, InliningCase (spec_types, item))
1226+ signature_union = Union{signature_union, match. spec_types}
1227+ fully_covered &= handle_match! (match, atypes, state, cases)
12461228 end
12471229 end
12481230
@@ -1253,9 +1235,8 @@ function analyze_single_call!(
12531235 if length (infos) > 1
12541236 (metharg, methsp) = ccall (:jl_type_intersection_with_env , Any, (Any, Any),
12551237 atype, only_method. sig):: SimpleVector
1256- match = MethodMatch (metharg, methsp, only_method, true )
1238+ match = MethodMatch (metharg, methsp:: SimpleVector , only_method, true )
12571239 else
1258- meth = meth:: MethodLookupResult
12591240 @assert length (meth) == 1
12601241 match = meth[1 ]
12611242 end
@@ -1268,46 +1249,41 @@ function analyze_single_call!(
12681249 fully_covered = false
12691250 end
12701251
1271- # If we only have one case and that case is fully covered, we may either
1272- # be able to do the inlining now (for constant cases), or push it directly
1273- # onto the todo list
1274- if fully_covered && length (cases) == 1
1275- handle_single_case! (ir, stmt, idx, cases[1 ]. item, false , todo)
1276- elseif length (cases) > 0
1277- push! (todo, idx=> UnionSplit (fully_covered, atype, cases))
1278- end
1279- return nothing
1252+ handle_cases! (ir, idx, stmt, sig, cases, fully_covered, todo)
12801253end
12811254
1282- # try to create `InliningCase`s using constant-prop'ed results
1283- # currently it works only when constant-prop' succeeded for all (union-split) signatures
1284- # TODO use any of constant-prop'ed results, and leave the other unhandled cases to later
1285- # TODO this function contains a lot of duplications with `analyze_single_call!`, factor them out
1286- function maybe_handle_const_call! (
1287- ir :: IRCode , idx :: Int , stmt :: Expr , (; results):: ConstCallInfo , (; atypes, atype) :: Signature ,
1288- state :: InliningState , isinvoke :: Bool , todo :: Vector{Pair{Int, Any}} )
1289- cases = InliningCase[] # TODO avoid this allocation for single cases ?
1255+ # similar to `analyze_single_call!`, but with constant results
1256+ function handle_const_call! (
1257+ ir :: IRCode , idx :: Int , stmt :: Expr , cinfo :: ConstCallInfo ,
1258+ sig :: Signature , state :: InliningState , todo :: Vector{Pair{Int, Any}} )
1259+ (; atypes, atype) = sig
1260+ (; call, results) = cinfo
1261+ infos = isa (call, MethodMatchInfo) ? MethodMatchInfo[call] : call . matches
1262+ cases = InliningCase[]
12901263 local fully_covered = true
12911264 local signature_union = Bottom
1292- for result in results
1293- isa (result, InferenceResult) || return false
1294- (; mi) = item = InliningTodo (result, atypes)
1295- spec_types = mi. specTypes
1296- signature_union = Union{signature_union, spec_types}
1297- if ! isdispatchtuple (spec_types)
1298- fully_covered = false
1299- continue
1300- end
1301- if ! validate_sparams (mi. sparam_vals)
1302- fully_covered = false
1265+ local j = 0
1266+ for i in 1 : length (infos)
1267+ meth = infos[i]. results
1268+ if meth. ambig
1269+ # Too many applicable methods
1270+ # Or there is a (partial?) ambiguity
1271+ return nothing
1272+ elseif length (meth) == 0
1273+ # No applicable methods; try next union split
13031274 continue
13041275 end
1305- state. mi_cache != = nothing && (item = resolve_todo (item, state))
1306- if item === nothing
1307- fully_covered = false
1308- continue
1276+ for match in meth
1277+ j += 1
1278+ result = results[j]
1279+ if result === nothing
1280+ signature_union = Union{signature_union, match. spec_types}
1281+ fully_covered &= handle_match! (match, atypes, state, cases)
1282+ else
1283+ signature_union = Union{signature_union, result. linfo. specTypes}
1284+ fully_covered &= handle_const_result! (result, atypes, state, cases)
1285+ end
13091286 end
1310- push! (cases, InliningCase (spec_types, item))
13111287 end
13121288
13131289 # if the signature is fully covered and there is only one applicable method,
@@ -1316,25 +1292,54 @@ function maybe_handle_const_call!(
13161292 if length (cases) == 0 && length (results) == 1
13171293 (; mi) = item = InliningTodo (results[1 ]:: InferenceResult , atypes)
13181294 state. mi_cache != = nothing && (item = resolve_todo (item, state))
1319- validate_sparams (mi. sparam_vals) || return true
1320- item === nothing && return true
1295+ validate_sparams (mi. sparam_vals) || return nothing
1296+ item === nothing && return nothing
13211297 push! (cases, InliningCase (mi. specTypes, item))
13221298 fully_covered = true
13231299 end
13241300 else
13251301 fully_covered = false
13261302 end
13271303
1304+ handle_cases! (ir, idx, stmt, sig, cases, fully_covered, todo)
1305+ end
1306+
1307+ function handle_match! (
1308+ match:: MethodMatch , argtypes:: Vector{Any} , state:: InliningState ,
1309+ cases:: Vector{InliningCase} )
1310+ spec_types = match. spec_types
1311+ isdispatchtuple (spec_types) || return false
1312+ item = analyze_method! (match, argtypes, state)
1313+ item === nothing && return false
1314+ _any (case-> case. sig === spec_types, cases) && return true
1315+ push! (cases, InliningCase (spec_types, item))
1316+ return true
1317+ end
1318+
1319+ function handle_const_result! (
1320+ result:: InferenceResult , argtypes:: Vector{Any} , state:: InliningState ,
1321+ cases:: Vector{InliningCase} )
1322+ (; mi) = item = InliningTodo (result, argtypes)
1323+ spec_types = mi. specTypes
1324+ isdispatchtuple (spec_types) || return false
1325+ validate_sparams (mi. sparam_vals) || return false
1326+ state. mi_cache != = nothing && (item = resolve_todo (item, state))
1327+ item === nothing && return false
1328+ push! (cases, InliningCase (spec_types, item))
1329+ return true
1330+ end
1331+
1332+ function handle_cases! (ir:: IRCode , idx:: Int , stmt:: Expr , sig:: Signature ,
1333+ cases:: Vector{InliningCase} , fully_covered:: Bool , todo:: Vector{Pair{Int, Any}} )
13281334 # If we only have one case and that case is fully covered, we may either
13291335 # be able to do the inlining now (for constant cases), or push it directly
13301336 # onto the todo list
13311337 if fully_covered && length (cases) == 1
13321338 handle_single_case! (ir, stmt, idx, cases[1 ]. item, isinvoke, todo)
13331339 elseif length (cases) > 0
1334- isinvoke && rewrite_invoke_exprargs! (stmt)
1335- push! (todo, idx=> UnionSplit (fully_covered, atype, cases))
1340+ push! (todo, idx=> UnionSplit (fully_covered, sig. atype, cases))
13361341 end
1337- return true
1342+ return nothing
13381343end
13391344
13401345function handle_const_opaque_closure_call! (
@@ -1371,9 +1376,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
13711376 ir. stmts[idx][:flag ] |= IR_FLAG_EFFECT_FREE
13721377 info = info. info
13731378 end
1374-
1375- # Inference determined this couldn't be analyzed. Don't question it.
13761379 if info === false
1380+ # Inference determined this couldn't be analyzed. Don't question it.
13771381 continue
13781382 end
13791383
0 commit comments