Skip to content

Commit 17e0bba

Browse files
authored
optimizer: eliminate safe typeassert calls (JuliaLang#42706)
Adds a very simple optimization pass to eliminate `typeassert` calls. The motivation is, when SROA replaces `getfield` calls with scalar values, then we can often prove `typeassert` whose first operand is a replaced value is no-op: ```julia julia> struct Foo; x; end julia> code_typed((Int,)) do a x1 = Foo(a) x2 = Foo(x1) typeassert(x2.x, Foo).x end |> only |> first CodeInfo( 1 ─ %1 = Main.Foo::Type{Foo} │ %2 = %new(%1, a)::Foo │ Main.typeassert(%2, Main.Foo)::Foo # can be nullified └── return a ) ``` Nullifying `typeassert` helps succeeding (simple) DCE to eliminate dead allocations, and also allows LLVM to do more aggressive DCE to emit simpler code. Here is a simple benchmarking: > sample target code: ```julia julia> function compute(T, n) r = 0 for i in 1:n x1 = T(i) x2 = T(x1) r += (x2.x::T).x::Int end r end compute (generic function with 1 method) julia> struct Foo; x; end julia> mutable struct Bar; x; end ``` > on master ```julia julia> @benchmark compute(Foo, 1000) BenchmarkTools.Trial: 10000 samples with 8 evaluations. Range (min … max): 3.263 μs … 145.828 μs ┊ GC (min … max): 0.00% … 97.14% Time (median): 3.516 μs ┊ GC (median): 0.00% Time (mean ± σ): 4.015 μs ± 3.726 μs ┊ GC (mean ± σ): 3.16% ± 3.46% ▇█▆▄▅▄▄▃▂▁▂▁ ▂ ▇███████████████▇██▇▇█▇▇▆▇▇▇▇▅▆▅▇▇▅██▇▇▆▇▇▇█▇█▇▇▅▆▆▆▆▅▅▅▅▄▄ █ 3.26 μs Histogram: log(frequency) by time 8.52 μs < Memory estimate: 7.64 KiB, allocs estimate: 489. julia> @benchmark compute(Bar, 1000) BenchmarkTools.Trial: 10000 samples with 4 evaluations. Range (min … max): 6.990 μs … 288.079 μs ┊ GC (min … max): 0.00% … 97.03% Time (median): 7.657 μs ┊ GC (median): 0.00% Time (mean ± σ): 9.019 μs ± 9.710 μs ┊ GC (mean ± σ): 4.59% ± 4.28% ▆█▆▄▃▂▂▁▂▃▂▁ ▁ ▁ ██████████████████████▇▇▇▇▇▆██████▇▇█▇▇▇▆▆▆▆▅▆▅▄▄▄▅▄▄▃▄▄▂▄▅ █ 6.99 μs Histogram: log(frequency) by time 20.7 μs < Memory estimate: 23.27 KiB, allocs estimate: 1489. ``` > on this branch ```julia julia> @benchmark compute(Foo, 1000) BenchmarkTools.Trial: 10000 samples with 1000 evaluations. Range (min … max): 1.234 ns … 116.188 ns ┊ GC (min … max): 0.00% … 0.00% Time (median): 1.246 ns ┊ GC (median): 0.00% Time (mean ± σ): 1.307 ns ± 1.444 ns ┊ GC (mean ± σ): 0.00% ± 0.00% █▇ ▂▂▁ ▂ ▁ ██████▇█▇▅▄▆▇▆▁▃▄▁▁▁▁▁▃▁▃▁▁▄▇▅▃▃▃▁▃▄▁▃▃▁▃▁▁▃▁▁▁▄▃▁▃▇███▇▇▇▆ █ 1.23 ns Histogram: log(frequency) by time 1.94 ns < Memory estimate: 0 bytes, allocs estimate: 0. julia> @benchmark compute(Bar, 1000) BenchmarkTools.Trial: 10000 samples with 1000 evaluations. Range (min … max): 1.234 ns … 33.790 ns ┊ GC (min … max): 0.00% … 0.00% Time (median): 1.245 ns ┊ GC (median): 0.00% Time (mean ± σ): 1.297 ns ± 0.677 ns ┊ GC (mean ± σ): 0.00% ± 0.00% █▇ ▃▂▁ ▁ ██████▆▆▅▁▄▅▅▄▁▄▄▄▃▄▃▁▃▁▃▄▃▁▃▁▃▁▁▁▃▃▁▃▃▁▁▁▁▁▁▁▃▁▄█████▇▇▇▇ █ 1.23 ns Histogram: log(frequency) by time 1.96 ns < Memory estimate: 0 bytes, allocs estimate: 0. ``` This `typeassert` elimination would be much more effective if we implement more aggressive SROA based on strong [alias analysis](https://github.com/aviatesk/EscapeAnalysis.jl) and/or [more aggressive Julia-level DCE](JuliaLang#27547). But this change is so simple that I don't think it hurts anything to have it for now.
1 parent a4903fd commit 17e0bba

File tree

5 files changed

+76
-36
lines changed

5 files changed

+76
-36
lines changed

base/compiler/optimize.jl

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -317,31 +317,26 @@ function optimize(interp::AbstractInterpreter, opt::OptimizationState, params::O
317317
end
318318

319319
function run_passes(ci::CodeInfo, sv::OptimizationState)
320-
preserve_coverage = coverage_enabled(sv.mod)
321-
ir = convert_to_ircode(ci, copy_exprargs(ci.code), preserve_coverage, sv)
322-
ir = slot2reg(ir, ci, sv)
323-
#@Base.show ("after_construct", ir)
320+
@timeit "convert" ir = convert_to_ircode(ci, sv)
321+
@timeit "slot2reg" ir = slot2reg(ir, ci, sv)
324322
# TODO: Domsorting can produce an updated domtree - no need to recompute here
325323
@timeit "compact 1" ir = compact!(ir)
326-
@timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
327-
#@timeit "verify 2" verify_ir(ir)
328-
ir = compact!(ir)
329-
#@Base.show ("before_sroa", ir)
330-
@timeit "SROA" ir = getfield_elim_pass!(ir)
331-
#@Base.show ir.new_nodes
332-
#@Base.show ("after_sroa", ir)
333-
ir = adce_pass!(ir)
334-
#@Base.show ("after_adce", ir)
324+
@timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
325+
# @timeit "verify 2" verify_ir(ir)
326+
@timeit "compact 2" ir = compact!(ir)
327+
@timeit "SROA" ir = getfield_elim_pass!(ir)
328+
@timeit "ADCE" ir = adce_pass!(ir)
335329
@timeit "type lift" ir = type_lift_pass!(ir)
336330
@timeit "compact 3" ir = compact!(ir)
337-
#@Base.show ir
338331
if JLOptions().debug_level == 2
339332
@timeit "verify 3" (verify_ir(ir); verify_linetable(ir.linetable))
340333
end
341334
return ir
342335
end
343336

344-
function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, sv::OptimizationState)
337+
function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
338+
code = copy_exprargs(ci.code)
339+
coverage = coverage_enabled(sv.mod)
345340
# Go through and add an unreachable node after every
346341
# Union{} call. Then reindex labels.
347342
idx = 1

base/compiler/ssair/ir.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1467,3 +1467,8 @@ function iterate(x::BBIdxIter, (idx, bb)::Tuple{Int, Int}=(1, 1))
14671467
end
14681468
return (bb, idx), (idx + 1, next_bb)
14691469
end
1470+
1471+
is_known_call(e::Expr, @nospecialize(func), ir::IRCode) =
1472+
is_known_call(e, func, ir, ir.sptypes, ir.argtypes)
1473+
1474+
argextype(@nospecialize(x), ir::IRCode) = argextype(x, ir, ir.sptypes, ir.argtypes)

base/compiler/ssair/passes.jl

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -449,10 +449,10 @@ function lift_comparison!(compact::IncrementalCompact, idx::Int,
449449
lifted_val = perform_lifting!(compact, visited_phinodes, cmp, lifting_cache, Bool, lifted_leaves, val)
450450
@assert lifted_val !== nothing
451451

452-
#global assertion_counter
453-
#assertion_counter::Int += 1
454-
#insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), lifted_val), nothing, 0, true)
455-
#return
452+
# global assertion_counter
453+
# assertion_counter::Int += 1
454+
# insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), lifted_val), nothing, 0, true)
455+
# return
456456
compact[idx] = lifted_val.x
457457
end
458458

@@ -734,17 +734,6 @@ function getfield_elim_pass!(ir::IRCode)
734734
result_t = make_MaybeUndef(result_t)
735735
end
736736

737-
# @Base.show result_t
738-
# @Base.show stmt
739-
# for (k,v) in lifted_leaves
740-
# @Base.show (k, v)
741-
# if isa(k, AnySSAValue)
742-
# @Base.show compact[k]
743-
# end
744-
# if isa(v, RefValue) && isa(v.x, AnySSAValue)
745-
# @Base.show compact[v.x]
746-
# end
747-
# end
748737
val = perform_lifting!(compact, visited_phinodes, field, lifting_cache, result_t, lifted_leaves, stmt.args[2])
749738

750739
# Insert the undef check if necessary
@@ -761,8 +750,8 @@ function getfield_elim_pass!(ir::IRCode)
761750

762751
# global assertion_counter
763752
# assertion_counter::Int += 1
764-
#insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true)
765-
#continue
753+
# insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true)
754+
# continue
766755
compact[idx] = val === nothing ? nothing : val.x
767756
end
768757

@@ -894,7 +883,8 @@ function getfield_elim_pass!(ir::IRCode)
894883
ir[SSAValue(use)] = new_expr
895884
end
896885
end
897-
ir
886+
887+
return ir
898888
end
899889
# assertion_counter = 0
900890

@@ -935,7 +925,21 @@ end
935925
"""
936926
adce_pass!(ir::IRCode) -> newir::IRCode
937927
938-
Aggressive Dead Code Elimination pass to eliminate code.
928+
Aggressive Dead Code Elimination pass.
929+
930+
In addition to a simple DCE for unused values and allocations,
931+
this pass also nullifies `typeassert` calls that can be proved to be no-op,
932+
in order to allow LLVM to emit simpler code down the road.
933+
934+
Note that this pass is more effective after SROA optimization (i.e. `getfield_elim_pass!`),
935+
since SROA often allows this pass to:
936+
- eliminate allocation of object whose field references are all replaced with scalar values, and
937+
- nullify `typeassert` call whose first operand has been replaced with a scalar value
938+
(, which may have introduced new type information that inference did not understand)
939+
940+
Also note that currently this pass _needs_ to run after `getfield_elim_pass!`, because
941+
the `typeassert` elimination depends on the transformation within `getfield_elim_pass!`
942+
which redirects references of `typeassert`ed value to the corresponding `PiNode`.
939943
"""
940944
function adce_pass!(ir::IRCode)
941945
phi_uses = fill(0, length(ir.stmts) + length(ir.new_nodes))
@@ -944,6 +948,14 @@ function adce_pass!(ir::IRCode)
944948
for ((_, idx), stmt) in compact
945949
if isa(stmt, PhiNode)
946950
push!(all_phis, idx)
951+
elseif isexpr(stmt, :call)
952+
# nullify safe `typeassert` calls
953+
if is_known_call(stmt, typeassert, compact) && length(stmt.args) == 3
954+
ty, isexact = instanceof_tfunc(compact_exprtype(compact, stmt.args[3]))
955+
if isexact && compact_exprtype(compact, stmt.args[2]) ty
956+
compact[idx] = nothing
957+
end
958+
end
947959
end
948960
end
949961
non_dce_finish!(compact)

test/compiler/inference.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3422,8 +3422,7 @@ let
34223422
ci.ssavaluetypes = Any[Any for i = 1:ci.ssavaluetypes]
34233423
sv = Core.Compiler.OptimizationState(mi, Core.Compiler.OptimizationParams(),
34243424
Core.Compiler.NativeInterpreter())
3425-
ir = Core.Compiler.convert_to_ircode(ci, Core.Compiler.copy_exprargs(ci.code),
3426-
false, sv)
3425+
ir = Core.Compiler.convert_to_ircode(ci, sv)
34273426
ir = Core.Compiler.slot2reg(ir, ci, sv)
34283427
ir = Core.Compiler.compact!(ir)
34293428
Core.Compiler.replace_code_newstyle!(ci, ir, 4)

test/compiler/irpasses.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,3 +425,32 @@ let # `getfield_elim_pass!` should work with constant globals
425425
return Meta.isexpr(stmt, :new)
426426
end
427427
end
428+
429+
let # `typeassert_elim_pass!`
430+
src = @eval Module() begin
431+
struct Foo; x; end
432+
433+
code_typed((Int,)) do a
434+
x1 = Foo(a)
435+
x2 = Foo(x1)
436+
x3 = Foo(x2)
437+
438+
r1 = (x2.x::Foo).x
439+
r2 = (x2.x::Foo).x::Int
440+
r3 = (x2.x::Foo).x::Integer
441+
r4 = ((x3.x::Foo).x::Foo).x
442+
443+
return r1, r2, r3, r4
444+
end |> only |> first
445+
end
446+
# eliminate `typeassert(f2.a, Foo)`
447+
@test all(src.code) do @nospecialize(stmt)
448+
Meta.isexpr(stmt, :call) || return true
449+
ft = Core.Compiler.argextype(stmt.args[1], src, Any[], src.slottypes)
450+
return Core.Compiler.widenconst(ft) !== typeof(typeassert)
451+
end
452+
# succeeding simple DCE will eliminate `Foo(a)`
453+
@test all(src.code) do @nospecialize(stmt)
454+
return !Meta.isexpr(stmt, :new)
455+
end
456+
end

0 commit comments

Comments
 (0)