Skip to content

Commit 63f5b8a

Browse files
authored
optimizer: clean up query interfaces (#43324)
- unify `compact_exprtype` and `argextype` - remove redundant arguments - unify `is_known_call` definitions and improve the precision of `is_known_call(..., ::IRCode)` (by using `singleton_type`)
1 parent f60bfd1 commit 63f5b8a

File tree

9 files changed

+186
-217
lines changed

9 files changed

+186
-217
lines changed

base/compiler/optimize.jl

Lines changed: 135 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
186186
return false
187187
end
188188
if isa(stmt, GotoIfNot)
189-
t = argextype(stmt.cond, ir, ir.sptypes)
189+
t = argextype(stmt.cond, ir)
190190
return !(t Bool)
191191
end
192192
if isa(stmt, Expr)
@@ -195,6 +195,127 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
195195
return true
196196
end
197197

198+
"""
199+
stmt_effect_free(stmt, rt, src::Union{IRCode,IncrementalCompact})
200+
201+
Determine whether a `stmt` is "side-effect-free", i.e. may be removed if it has no uses.
202+
"""
203+
function stmt_effect_free(@nospecialize(stmt), @nospecialize(rt), src::Union{IRCode,IncrementalCompact})
204+
isa(stmt, PiNode) && return true
205+
isa(stmt, PhiNode) && return true
206+
isa(stmt, ReturnNode) && return false
207+
isa(stmt, GotoNode) && return false
208+
isa(stmt, GotoIfNot) && return false
209+
isa(stmt, Slot) && return false # Slots shouldn't occur in the IR at this point, but let's be defensive here
210+
isa(stmt, GlobalRef) && return isdefined(stmt.mod, stmt.name)
211+
if isa(stmt, Expr)
212+
(; head, args) = stmt
213+
if head === :static_parameter
214+
etyp = (isa(src, IRCode) ? src.sptypes : src.ir.sptypes)[args[1]::Int]
215+
# if we aren't certain enough about the type, it might be an UndefVarError at runtime
216+
return isa(etyp, Const)
217+
end
218+
if head === :call
219+
f = argextype(args[1], src)
220+
f = singleton_type(f)
221+
f === nothing && return false
222+
is_return_type(f) && return true
223+
if isa(f, IntrinsicFunction)
224+
intrinsic_effect_free_if_nothrow(f) || return false
225+
return intrinsic_nothrow(f,
226+
Any[argextype(args[i], src) for i = 2:length(args)])
227+
end
228+
contains_is(_PURE_BUILTINS, f) && return true
229+
contains_is(_PURE_OR_ERROR_BUILTINS, f) || return false
230+
rt === Bottom && return false
231+
return _builtin_nothrow(f, Any[argextype(args[i], src) for i = 2:length(args)], rt)
232+
elseif head === :new
233+
typ = argextype(args[1], src)
234+
# `Expr(:new)` of unknown type could raise arbitrary TypeError.
235+
typ, isexact = instanceof_tfunc(typ)
236+
isexact || return false
237+
isconcretedispatch(typ) || return false
238+
typ = typ::DataType
239+
fieldcount(typ) >= length(args) - 1 || return false
240+
for fld_idx in 1:(length(args) - 1)
241+
eT = argextype(args[fld_idx + 1], src)
242+
fT = fieldtype(typ, fld_idx)
243+
eT fT || return false
244+
end
245+
return true
246+
elseif head === :new_opaque_closure
247+
length(args) < 5 && return false
248+
typ = argextype(args[1], src)
249+
typ, isexact = instanceof_tfunc(typ)
250+
isexact || return false
251+
typ Tuple || return false
252+
isva = argextype(args[2], src)
253+
rt_lb = argextype(args[3], src)
254+
rt_ub = argextype(args[4], src)
255+
src = argextype(args[5], src)
256+
if !(isva Bool && rt_lb Type && rt_ub Type && src Method)
257+
return false
258+
end
259+
return true
260+
elseif head === :isdefined || head === :the_exception || head === :copyast || head === :inbounds || head === :boundscheck
261+
return true
262+
else
263+
# e.g. :loopinfo
264+
return false
265+
end
266+
end
267+
return true
268+
end
269+
270+
"""
271+
argextype(x, src::Union{IRCode,IncrementalCompact}) -> t
272+
argextype(x, src::CodeInfo, sptypes::Vector{Any}) -> t
273+
274+
Return the type of value `x` in the context of inferred source `src`.
275+
Note that `t` might be an extended lattice element.
276+
Use `widenconst(t)` to get the native Julia type of `x`.
277+
"""
278+
argextype(@nospecialize(x), ir::IRCode, sptypes::Vector{Any} = ir.sptypes) =
279+
argextype(x, ir, sptypes, ir.argtypes)
280+
function argextype(@nospecialize(x), compact::IncrementalCompact, sptypes::Vector{Any} = compact.ir.sptypes)
281+
isa(x, AnySSAValue) && return types(compact)[x]
282+
return argextype(x, compact, sptypes, compact.ir.argtypes)
283+
end
284+
argextype(@nospecialize(x), src::CodeInfo, sptypes::Vector{Any}) = argextype(x, src, sptypes, src.slottypes::Vector{Any})
285+
function argextype(
286+
@nospecialize(x), src::Union{IRCode,IncrementalCompact,CodeInfo},
287+
sptypes::Vector{Any}, slottypes::Vector{Any})
288+
if isa(x, Expr)
289+
if x.head === :static_parameter
290+
return sptypes[x.args[1]::Int]
291+
elseif x.head === :boundscheck
292+
return Bool
293+
elseif x.head === :copyast
294+
return argextype(x.args[1], src, sptypes, slottypes)
295+
end
296+
@assert false "argextype only works on argument-position values"
297+
elseif isa(x, SlotNumber)
298+
return slottypes[x.id]
299+
elseif isa(x, TypedSlot)
300+
return x.typ
301+
elseif isa(x, SSAValue)
302+
return abstract_eval_ssavalue(x, src)
303+
elseif isa(x, Argument)
304+
return slottypes[x.n]
305+
elseif isa(x, QuoteNode)
306+
return Const(x.value)
307+
elseif isa(x, GlobalRef)
308+
return abstract_eval_global(x.mod, x.name)
309+
elseif isa(x, PhiNode)
310+
return Any
311+
elseif isa(x, PiNode)
312+
return x.typ
313+
else
314+
return Const(x)
315+
end
316+
end
317+
abstract_eval_ssavalue(s::SSAValue, src::Union{IRCode,IncrementalCompact}) = types(src)[s]
318+
198319
# compute inlining cost and sideeffects
199320
function finish(interp::AbstractInterpreter, opt::OptimizationState, params::OptimizationParams, ir::IRCode, @nospecialize(result))
200321
(; src, linfo) = opt
@@ -214,7 +335,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
214335
for i in 1:length(ir.stmts)
215336
node = ir.stmts[i]
216337
stmt = node[:inst]
217-
if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, node[:type], ir, ir.sptypes)
338+
if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, node[:type], ir)
218339
proven_pure = false
219340
break
220341
end
@@ -432,20 +553,19 @@ plus_saturate(x::Int, y::Int) = max(x, y, x+y)
432553
isknowntype(@nospecialize T) = (T === Union{}) || isa(T, Const) || isconcretetype(widenconst(T))
433554

434555
function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{Any},
435-
slottypes::Vector{Any}, union_penalties::Bool,
436-
params::OptimizationParams, error_path::Bool = false)
556+
union_penalties::Bool, params::OptimizationParams, error_path::Bool = false)
437557
head = ex.head
438558
if is_meta_expr_head(head)
439559
return 0
440560
elseif head === :call
441561
farg = ex.args[1]
442-
ftyp = argextype(farg, src, sptypes, slottypes)
562+
ftyp = argextype(farg, src, sptypes)
443563
if ftyp === IntrinsicFunction && farg isa SSAValue
444564
# if this comes from code that was already inlined into another function,
445565
# Consts have been widened. try to recover in simple cases.
446566
farg = isa(src, CodeInfo) ? src.code[farg.id] : src.stmts[farg.id][:inst]
447567
if isa(farg, GlobalRef) || isa(farg, QuoteNode) || isa(farg, IntrinsicFunction) || isexpr(farg, :static_parameter)
448-
ftyp = argextype(farg, src, sptypes, slottypes)
568+
ftyp = argextype(farg, src, sptypes)
449569
end
450570
end
451571
f = singleton_type(ftyp)
@@ -467,15 +587,15 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
467587
# return plus_saturate(argcost, isknowntype(extyp) ? 1 : params.inline_nonleaf_penalty)
468588
return 0
469589
elseif (f === Core.arrayref || f === Core.const_arrayref || f === Core.arrayset) && length(ex.args) >= 3
470-
atyp = argextype(ex.args[3], src, sptypes, slottypes)
590+
atyp = argextype(ex.args[3], src, sptypes)
471591
return isknowntype(atyp) ? 4 : error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty
472-
elseif f === typeassert && isconstType(widenconst(argextype(ex.args[3], src, sptypes, slottypes)))
592+
elseif f === typeassert && isconstType(widenconst(argextype(ex.args[3], src, sptypes)))
473593
return 1
474594
elseif f === Core.isa
475595
# If we're in a union context, we penalize type computations
476596
# on union types. In such cases, it is usually better to perform
477597
# union splitting on the outside.
478-
if union_penalties && isa(argextype(ex.args[2], src, sptypes, slottypes), Union)
598+
if union_penalties && isa(argextype(ex.args[2], src, sptypes), Union)
479599
return params.inline_nonleaf_penalty
480600
end
481601
end
@@ -487,7 +607,7 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
487607
end
488608
return T_FFUNC_COST[fidx]
489609
end
490-
extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes, slottypes)
610+
extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes)
491611
if extyp === Union{}
492612
return 0
493613
end
@@ -498,7 +618,7 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
498618
# run-time of the function, we omit them from
499619
# consideration. This way, non-inlined error branches do not
500620
# prevent inlining.
501-
extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes, slottypes)
621+
extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes)
502622
return extyp === Union{} ? 0 : 20
503623
elseif head === :(=)
504624
if ex.args[1] isa GlobalRef
@@ -508,7 +628,7 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
508628
end
509629
a = ex.args[2]
510630
if a isa Expr
511-
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, slottypes, union_penalties, params, error_path))
631+
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, union_penalties, params, error_path))
512632
end
513633
return cost
514634
elseif head === :copyast
@@ -524,11 +644,11 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
524644
end
525645

526646
function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{Any},
527-
slottypes::Vector{Any}, union_penalties::Bool, params::OptimizationParams)
647+
union_penalties::Bool, params::OptimizationParams)
528648
thiscost = 0
529649
dst(tgt) = isa(src, IRCode) ? first(src.cfg.blocks[tgt].stmts) : tgt
530650
if stmt isa Expr
531-
thiscost = statement_cost(stmt, line, src, sptypes, slottypes, union_penalties, params,
651+
thiscost = statement_cost(stmt, line, src, sptypes, union_penalties, params,
532652
is_stmt_throw_block(isa(src, IRCode) ? src.stmts.flag[line] : src.ssaflags[line]))::Int
533653
elseif stmt isa GotoNode
534654
# loops are generally always expensive
@@ -546,7 +666,7 @@ function inline_worthy(ir::IRCode,
546666
bodycost::Int = 0
547667
for line = 1:length(ir.stmts)
548668
stmt = ir.stmts[line][:inst]
549-
thiscost = statement_or_branch_cost(stmt, line, ir, ir.sptypes, ir.argtypes, union_penalties, params)
669+
thiscost = statement_or_branch_cost(stmt, line, ir, ir.sptypes, union_penalties, params)
550670
bodycost = plus_saturate(bodycost, thiscost)
551671
bodycost > cost_threshold && return false
552672
end
@@ -558,7 +678,6 @@ function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeI
558678
for line = 1:length(body)
559679
stmt = body[line]
560680
thiscost = statement_or_branch_cost(stmt, line, src, sptypes,
561-
src isa CodeInfo ? src.slottypes : src.argtypes,
562681
unionpenalties, params)
563682
cost[line] = thiscost
564683
if thiscost > maxcost
@@ -568,14 +687,6 @@ function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeI
568687
return maxcost
569688
end
570689

571-
function is_known_call(e::Expr, @nospecialize(func), src, sptypes::Vector{Any}, slottypes::Vector{Any} = EMPTY_SLOTTYPES)
572-
if e.head !== :call
573-
return false
574-
end
575-
f = argextype(e.args[1], src, sptypes, slottypes)
576-
return isa(f, Const) && f.val === func
577-
end
578-
579690
function renumber_ir_elements!(body::Vector{Any}, changemap::Vector{Int})
580691
return renumber_ir_elements!(body, changemap, changemap)
581692
end

base/compiler/ssair/driver.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ include("compiler/ssair/basicblock.jl")
1414
include("compiler/ssair/domtree.jl")
1515
include("compiler/ssair/ir.jl")
1616
include("compiler/ssair/slot2ssa.jl")
17-
include("compiler/ssair/queries.jl")
1817
include("compiler/ssair/passes.jl")
1918
include("compiler/ssair/inlining.jl")
2019
include("compiler/ssair/verify.jl")

base/compiler/ssair/inlining.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
371371
return_value = SSAValue(idx′)
372372
inline_compact[idx′] = val
373373
inline_compact.result[idx′][:type] =
374-
compact_exprtype(isa(val, Argument) || isa(val, Expr) ? compact : inline_compact, val)
374+
argextype(val, isa(val, Argument) || isa(val, Expr) ? compact : inline_compact)
375375
break
376376
end
377377
inline_compact[idx′] = stmt′
@@ -400,7 +400,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
400400
if isa(val, GlobalRef) || isa(val, Expr)
401401
stmt′ = val
402402
inline_compact.result[idx′][:type] =
403-
compact_exprtype(isa(val, Expr) ? compact : inline_compact, val)
403+
argextype(val, isa(val, Expr) ? compact : inline_compact)
404404
insert_node_here!(inline_compact, NewInstruction(GotoNode(post_bb_id),
405405
Any, compact.result[idx′][:line]),
406406
true)
@@ -435,7 +435,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
435435
return_value = pn.values[1]
436436
else
437437
return_value = insert_node_here!(compact,
438-
NewInstruction(pn, compact_exprtype(compact, SSAValue(idx)), compact.result[idx][:line]))
438+
NewInstruction(pn, argextype(SSAValue(idx), compact), compact.result[idx][:line]))
439439
end
440440
end
441441
return_value
@@ -580,7 +580,7 @@ function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vect
580580
for aidx in 1:length(argexprs)
581581
aexpr = argexprs[aidx]
582582
if isa(aexpr, Expr) || isa(aexpr, GlobalRef)
583-
ninst = effect_free(NewInstruction(aexpr, compact_exprtype(compact, aexpr), compact.result[idx][:line]))
583+
ninst = effect_free(NewInstruction(aexpr, argextype(aexpr, compact), compact.result[idx][:line]))
584584
argexprs[aidx] = insert_node_here!(compact, ninst)
585585
end
586586
end
@@ -886,7 +886,7 @@ function inline_splatnew!(ir::IRCode, idx::Int, stmt::Expr, @nospecialize(rt))
886886
if nf isa Const
887887
eargs = stmt.args
888888
tup = eargs[2]
889-
tt = argextype(tup, ir, ir.sptypes)
889+
tt = argextype(tup, ir)
890890
tnf = nfields_tfunc(tt)
891891
# TODO: hoisting this tnf.val === nf.val check into codegen
892892
# would enable us to almost always do this transform
@@ -908,15 +908,15 @@ end
908908

909909
function call_sig(ir::IRCode, stmt::Expr)
910910
isempty(stmt.args) && return nothing
911-
ft = argextype(stmt.args[1], ir, ir.sptypes)
911+
ft = argextype(stmt.args[1], ir)
912912
has_free_typevars(ft) && return nothing
913913
f = singleton_type(ft)
914914
f === Core.Intrinsics.llvmcall && return nothing
915915
f === Core.Intrinsics.cglobal && return nothing
916916
argtypes = Vector{Any}(undef, length(stmt.args))
917917
argtypes[1] = ft
918918
for i = 2:length(stmt.args)
919-
a = argextype(stmt.args[i], ir, ir.sptypes)
919+
a = argextype(stmt.args[i], ir)
920920
(a === Bottom || isvarargtype(a)) && return nothing
921921
argtypes[i] = a
922922
end
@@ -1025,10 +1025,10 @@ end
10251025

10261026
function narrow_opaque_closure!(ir::IRCode, stmt::Expr, @nospecialize(info), state::InliningState)
10271027
if isa(info, OpaqueClosureCreateInfo)
1028-
lbt = argextype(stmt.args[3], ir, ir.sptypes)
1028+
lbt = argextype(stmt.args[3], ir)
10291029
lb, exact = instanceof_tfunc(lbt)
10301030
exact || return
1031-
ubt = argextype(stmt.args[4], ir, ir.sptypes)
1031+
ubt = argextype(stmt.args[4], ir)
10321032
ub, exact = instanceof_tfunc(ubt)
10331033
exact || return
10341034
# Narrow opaque closure type
@@ -1046,7 +1046,7 @@ end
10461046
# For primitives, we do that right here. For proper calls, we will
10471047
# discover this when we consult the caches.
10481048
function check_effect_free!(ir::IRCode, idx::Int, @nospecialize(stmt), @nospecialize(rt))
1049-
if stmt_effect_free(stmt, rt, ir, ir.sptypes)
1049+
if stmt_effect_free(stmt, rt, ir)
10501050
ir.stmts[idx][:flag] |= IR_FLAG_EFFECT_FREE
10511051
end
10521052
end
@@ -1346,7 +1346,7 @@ end
13461346

13471347
function mk_tuplecall!(compact::IncrementalCompact, args::Vector{Any}, line_idx::Int32)
13481348
e = Expr(:call, TOP_TUPLE, args...)
1349-
etyp = tuple_tfunc(Any[compact_exprtype(compact, args[i]) for i in 1:length(args)])
1349+
etyp = tuple_tfunc(Any[argextype(args[i], compact) for i in 1:length(args)])
13501350
return insert_node_here!(compact, NewInstruction(e, etyp, line_idx))
13511351
end
13521352

base/compiler/ssair/ir.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ function insert_node!(ir::IRCode, pos::Int, inst::NewInstruction, attach_after::
520520
node[:line] = something(inst.line, ir.stmts[pos][:line])
521521
flag = inst.flag
522522
if !inst.effect_free_computed
523-
if stmt_effect_free(inst.stmt, inst.type, ir, ir.sptypes)
523+
if stmt_effect_free(inst.stmt, inst.type, ir)
524524
flag |= IR_FLAG_EFFECT_FREE
525525
end
526526
end
@@ -765,7 +765,7 @@ function insert_node_here!(compact::IncrementalCompact, inst::NewInstruction, re
765765
resize!(compact, result_idx)
766766
end
767767
flag = inst.flag
768-
if !inst.effect_free_computed && stmt_effect_free(inst.stmt, inst.type, compact, compact.ir.sptypes)
768+
if !inst.effect_free_computed && stmt_effect_free(inst.stmt, inst.type, compact)
769769
flag |= IR_FLAG_EFFECT_FREE
770770
end
771771
node = compact.result[result_idx]
@@ -1316,7 +1316,7 @@ function maybe_erase_unused!(
13161316
callback = null_dce_callback)
13171317
stmt = compact.result[idx][:inst]
13181318
stmt === nothing && return false
1319-
if compact_exprtype(compact, SSAValue(idx)) === Bottom
1319+
if argextype(SSAValue(idx), compact) === Bottom
13201320
effect_free = false
13211321
else
13221322
effect_free = compact.result[idx][:flag] & IR_FLAG_EFFECT_FREE != 0
@@ -1466,8 +1466,3 @@ function iterate(x::BBIdxIter, (idx, bb)::Tuple{Int, Int}=(1, 1))
14661466
end
14671467
return (bb, idx), (idx + 1, next_bb)
14681468
end
1469-
1470-
is_known_call(e::Expr, @nospecialize(func), ir::IRCode) =
1471-
is_known_call(e, func, ir, ir.sptypes, ir.argtypes)
1472-
1473-
argextype(@nospecialize(x), ir::IRCode) = argextype(x, ir, ir.sptypes, ir.argtypes)

0 commit comments

Comments
 (0)