Skip to content

Commit 427824e

Browse files
authored
refactoring on SROA passes (#55262)
All changes are cosmetic and do not change the basic functionality: - Added the interface type to the callbacks received by `simple_walker` to clarify which objects are passed as callbacks to `simple_walker`. - Replaced ambiguous names like `idx` with more descriptive ones like `defidx` to make the algorithm easier to understand.
1 parent 4dfce5d commit 427824e

File tree

1 file changed

+88
-69
lines changed

1 file changed

+88
-69
lines changed

base/compiler/ssair/passes.jl

Lines changed: 88 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ end
7979
function find_curblock(domtree::DomTree, allblocks::BitSet, curblock::Int)
8080
# TODO: This can be much faster by looking at current level and only
8181
# searching for those blocks in a sorted order
82-
while !(curblock in allblocks) && curblock !== 0
82+
while curblock allblocks && curblock 0
8383
curblock = domtree.idoms_bb[curblock]
8484
end
8585
return curblock
@@ -190,18 +190,21 @@ function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospec
190190
return walk_to_defs(compact, val, typeconstraint, predecessors, 𝕃ₒ)
191191
end
192192

193-
function trivial_walker(@nospecialize(pi), @nospecialize(idx))
194-
return nothing
195-
end
193+
abstract type WalkerCallback end
196194

197-
function pi_walker(@nospecialize(pi), @nospecialize(idx))
198-
if isa(pi, PiNode)
199-
return LiftedValue(pi.val)
195+
struct TrivialWalker <: WalkerCallback end
196+
(::TrivialWalker)(@nospecialize(def), @nospecialize(defssa::AnySSAValue)) = nothing
197+
198+
struct PiWalker <: WalkerCallback end
199+
function (::PiWalker)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
200+
if isa(def, PiNode)
201+
return LiftedValue(def.val)
200202
end
201203
return nothing
202204
end
203205

204-
function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), callback=trivial_walker)
206+
function simple_walk(compact::IncrementalCompact, @nospecialize(defssa::AnySSAValue),
207+
walker_callback::WalkerCallback=TrivialWalker())
205208
while true
206209
if isa(defssa, OldSSAValue)
207210
if already_inserted(compact, defssa)
@@ -218,15 +221,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
218221
end
219222
def = compact[defssa][:stmt]
220223
if isa(def, AnySSAValue)
221-
callback(def, defssa)
224+
walker_callback(def, defssa)
222225
if isa(def, SSAValue)
223226
is_old(compact, defssa) && (def = OldSSAValue(def.id))
224227
end
225228
defssa = def
226229
elseif isa(def, Union{PhiNode, PhiCNode, GlobalRef})
227230
return defssa
228231
else
229-
new_def = callback(def, defssa)
232+
new_def = walker_callback(def, defssa)
230233
if new_def === nothing
231234
return defssa
232235
end
@@ -241,16 +244,21 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
241244
end
242245
end
243246

244-
function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
245-
@nospecialize(typeconstraint))
246-
callback = function (@nospecialize(pi), @nospecialize(idx))
247-
if isa(pi, PiNode)
248-
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
249-
return LiftedValue(pi.val)
250-
end
251-
return nothing
247+
mutable struct TypeConstrainingWalker <: WalkerCallback
248+
typeconstraint::Any
249+
TypeConstrainingWalker(@nospecialize(typeconstraint::Any)) = new(typeconstraint)
250+
end
251+
function (walker_callback::TypeConstrainingWalker)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
252+
if isa(def, PiNode)
253+
walker_callback.typeconstraint =
254+
typeintersect(walker_callback.typeconstraint, widenconst(def.typ))
255+
return LiftedValue(def.val)
252256
end
253-
def = simple_walk(compact, defssa, callback)
257+
return nothing
258+
end
259+
function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(val::AnySSAValue),
260+
@nospecialize(typeconstraint))
261+
def = simple_walk(compact, val, TypeConstrainingWalker(typeconstraint))
254262
return Pair{Any, Any}(def, typeconstraint)
255263
end
256264

@@ -638,15 +646,17 @@ end
638646

639647
struct SkipToken end; const SKIP_TOKEN = SkipToken()
640648

641-
function lifted_value(compact::IncrementalCompact, @nospecialize(old_node_ssa#=::AnySSAValue=#), @nospecialize(old_value),
642-
lifted_philikes::Vector{LiftedPhilike}, lifted_leaves::Union{LiftedLeaves, LiftedDefs}, reverse_mapping::IdDict{AnySSAValue, Int},
643-
walker_callback)
649+
function lifted_value(compact::IncrementalCompact, @nospecialize(old_node_ssa::AnySSAValue),
650+
@nospecialize(old_value), lifted_philikes::Vector{LiftedPhilike},
651+
lifted_leaves::Union{LiftedLeaves, LiftedDefs},
652+
reverse_mapping::IdDict{AnySSAValue, Int},
653+
walker_callback::WalkerCallback)
644654
val = old_value
645655
if is_old(compact, old_node_ssa) && isa(val, SSAValue)
646656
val = OldSSAValue(val.id)
647657
end
648658
if isa(val, AnySSAValue)
649-
val = simple_walk(compact, val, def_walker(lifted_leaves, reverse_mapping, walker_callback))
659+
val = simple_walk(compact, val, LiftedLeaveWalker(lifted_leaves, reverse_mapping, walker_callback))
650660
end
651661
if val in keys(lifted_leaves)
652662
lifted_val = lifted_leaves[val]
@@ -656,7 +666,7 @@ function lifted_value(compact::IncrementalCompact, @nospecialize(old_node_ssa#=:
656666
lifted_val === nothing && return UNDEF_TOKEN
657667
val = lifted_val.val
658668
if isa(val, AnySSAValue)
659-
val = simple_walk(compact, val, pi_walker)
669+
val = simple_walk(compact, val, PiWalker())
660670
end
661671
return val
662672
elseif isa(val, AnySSAValue) && val in keys(reverse_mapping)
@@ -673,7 +683,7 @@ function is_old(compact, @nospecialize(old_node_ssa))
673683
return true
674684
end
675685

676-
struct PhiNest{C}
686+
struct PhiNest{C<:WalkerCallback}
677687
visited_philikes::Vector{AnySSAValue}
678688
lifted_philikes::Vector{LiftedPhilike}
679689
lifted_leaves::Union{LiftedLeaves, LiftedDefs}
@@ -743,20 +753,29 @@ function finish_phi_nest!(compact::IncrementalCompact, nest::PhiNest)
743753
end
744754
end
745755

746-
function def_walker(lifted_leaves::Union{LiftedLeaves, LiftedDefs}, reverse_mapping::IdDict{AnySSAValue, Int}, walker_callback)
747-
function (@nospecialize(walk_def), @nospecialize(defssa))
748-
if (defssa in keys(lifted_leaves)) || (isa(defssa, AnySSAValue) && defssa in keys(reverse_mapping))
749-
return nothing
750-
end
751-
isa(walk_def, PiNode) && return LiftedValue(walk_def.val)
752-
return walker_callback(walk_def, defssa)
756+
struct LiftedLeaveWalker{C<:WalkerCallback} <: WalkerCallback
757+
lifted_leaves::Union{LiftedLeaves, LiftedDefs}
758+
reverse_mapping::IdDict{AnySSAValue, Int}
759+
inner_walker_callback::C
760+
function LiftedLeaveWalker(@nospecialize(lifted_leaves::Union{LiftedLeaves, LiftedDefs}),
761+
@nospecialize(reverse_mapping::IdDict{AnySSAValue, Int}),
762+
inner_walker_callback::C) where C<:WalkerCallback
763+
return new{C}(lifted_leaves, reverse_mapping, inner_walker_callback)
753764
end
754765
end
766+
function (walker_callback::LiftedLeaveWalker)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
767+
(; lifted_leaves, reverse_mapping, inner_walker_callback) = walker_callback
768+
if defssa in keys(lifted_leaves) || defssa in keys(reverse_mapping)
769+
return nothing
770+
end
771+
isa(def, PiNode) && return LiftedValue(def.val)
772+
return inner_walker_callback(def, defssa)
773+
end
755774

756775
function perform_lifting!(compact::IncrementalCompact,
757776
visited_philikes::Vector{AnySSAValue}, @nospecialize(cache_key),
758777
@nospecialize(result_t), lifted_leaves::Union{LiftedLeaves, LiftedDefs}, @nospecialize(stmt_val),
759-
lazydomtree::Union{LazyDomtree,Nothing}, walker_callback = trivial_walker)
778+
lazydomtree::Union{LazyDomtree,Nothing}, walker_callback::WalkerCallback = TrivialWalker())
760779
reverse_mapping = IdDict{AnySSAValue, Int}()
761780
for id in 1:length(visited_philikes)
762781
reverse_mapping[visited_philikes[id]] = id
@@ -839,7 +858,7 @@ function perform_lifting!(compact::IncrementalCompact,
839858

840859
# Fixup the stmt itself
841860
if isa(stmt_val, Union{SSAValue, OldSSAValue})
842-
stmt_val = simple_walk(compact, stmt_val, def_walker(lifted_leaves, reverse_mapping, walker_callback))
861+
stmt_val = simple_walk(compact, stmt_val, LiftedLeaveWalker(lifted_leaves, reverse_mapping, walker_callback))
843862
end
844863

845864
if stmt_val in keys(lifted_leaves)
@@ -948,6 +967,17 @@ function keyvalue_predecessors(@nospecialize(key), 𝕃ₒ::AbstractLattice)
948967
end
949968
end
950969

970+
struct KeyValueWalker <: WalkerCallback
971+
compact::IncrementalCompact
972+
end
973+
function (walker_callback::KeyValueWalker)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
974+
if is_known_invoke_or_call(def, Core.OptimizedGenerics.KeyValue.set, walker_callback.compact)
975+
@assert length(def.args) in (5, 6)
976+
return LiftedValue(def.args[end-2])
977+
end
978+
return nothing
979+
end
980+
951981
function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr, 𝕃ₒ::AbstractLattice)
952982
collection = stmt.args[end-1]
953983
key = stmt.args[end]
@@ -964,16 +994,9 @@ function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr,
964994
result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
965995
end
966996

967-
function keyvalue_walker(@nospecialize(def), _)
968-
if is_known_invoke_or_call(def, Core.OptimizedGenerics.KeyValue.set, compact)
969-
@assert length(def.args) in (5, 6)
970-
return LiftedValue(def.args[end-2])
971-
end
972-
return nothing
973-
end
974997
(lifted_val, nest) = perform_lifting!(compact,
975998
visited_philikes, key, result_t, lifted_leaves, collection, nothing,
976-
keyvalue_walker)
999+
KeyValueWalker(compact))
9771000

9781001
compact[idx] = lifted_val === nothing ? nothing : Expr(:call, GlobalRef(Core, :tuple), lifted_val.val)
9791002
finish_phi_nest!(compact, nest)
@@ -1139,13 +1162,11 @@ end
11391162
# which can be very large sometimes, and program counters in question are often very sparse
11401163
const SPCSet = IdSet{Int}
11411164

1142-
struct IntermediaryCollector
1165+
struct IntermediaryCollector <: WalkerCallback
11431166
intermediaries::SPCSet
11441167
end
1145-
function (this::IntermediaryCollector)(@nospecialize(pi), @nospecialize(ssa))
1146-
if !isa(pi, Expr)
1147-
push!(this.intermediaries, ssa.id)
1148-
end
1168+
function (walker_callback::IntermediaryCollector)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
1169+
isa(def, Expr) || push!(walker_callback.intermediaries, defssa.id)
11491170
return nothing
11501171
end
11511172

@@ -1242,7 +1263,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
12421263
update_scope_mapping!(scope_mapping, bb+1, bbs)
12431264
end
12441265
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
1245-
is_setfield = is_isdefined = is_finalizer = is_keyvalue_get = false
1266+
is_setfield = is_isdefined = is_finalizer = false
12461267
field_ordering = :unspecified
12471268
if is_known_call(stmt, setfield!, compact)
12481269
4 <= length(stmt.args) <= 5 || continue
@@ -1371,8 +1392,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
13711392
if ismutabletypename(struct_typ_name)
13721393
isa(val, SSAValue) || continue
13731394
let intermediaries = SPCSet()
1374-
callback = IntermediaryCollector(intermediaries)
1375-
def = simple_walk(compact, val, callback)
1395+
def = simple_walk(compact, val, IntermediaryCollector(intermediaries))
13761396
# Mutable stuff here
13771397
isa(def, SSAValue) || continue
13781398
if defuses === nothing
@@ -1680,24 +1700,23 @@ end
16801700
function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}, lazydomtree::LazyDomtree, inlining::Union{Nothing, InliningState})
16811701
𝕃ₒ = inlining === nothing ? SimpleInferenceLattice.instance : optimizer_lattice(inlining.interp)
16821702
lazypostdomtree = LazyPostDomtree(ir)
1683-
for (idx, (intermediaries, defuse)) in defuses
1703+
for (defidx, (intermediaries, defuse)) in defuses
16841704
intermediaries = collect(intermediaries)
16851705
# Check if there are any uses we did not account for. If so, the variable
16861706
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
16871707
# not to include any intermediaries that have dead uses. As a result, missing uses will only ever
16881708
# show up in the nuses_total count.
16891709
nleaves = length(defuse.uses) + length(defuse.defs)
16901710
nuses = 0
1691-
for idx in intermediaries
1692-
nuses += used_ssas[idx]
1711+
for iidx in intermediaries
1712+
nuses += used_ssas[iidx]
16931713
end
1694-
nuses_total = used_ssas[idx] + nuses - length(intermediaries)
1714+
nuses_total = used_ssas[defidx] + nuses - length(intermediaries)
16951715
nleaves == nuses_total || continue
16961716
# Find the type for this allocation
1697-
defexpr = ir[SSAValue(idx)][:stmt]
1717+
defexpr = ir[SSAValue(defidx)][:stmt]
16981718
isexpr(defexpr, :new) || continue
1699-
newidx = idx
1700-
typ = unwrap_unionall(ir.stmts[newidx][:type])
1719+
typ = unwrap_unionall(ir.stmts[defidx][:type])
17011720
# Could still end up here if we tried to setfield! on an immutable, which would
17021721
# error at runtime, but is not illegal to have in the IR.
17031722
typ = widenconst(typ)
@@ -1713,7 +1732,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
17131732
end
17141733
end
17151734
if finalizer_idx !== nothing && inlining !== nothing
1716-
try_resolve_finalizer!(ir, idx, finalizer_idx, defuse, inlining,
1735+
try_resolve_finalizer!(ir, defidx, finalizer_idx, defuse, inlining,
17171736
lazydomtree, lazypostdomtree, ir[SSAValue(finalizer_idx)][:info])
17181737
continue
17191738
end
@@ -1752,11 +1771,11 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
17521771
# but we should come up with semantics for well defined semantics
17531772
# for uninitialized fields first.
17541773
ndefuse = length(fielddefuse)
1755-
blocks = Vector{Tuple{#=phiblocks=# Vector{Int}, #=allblocks=# BitSet}}(undef, ndefuse)
1774+
blocks = Vector{Tuple{#=phiblocks=#Vector{Int},#=allblocks=#BitSet}}(undef, ndefuse)
17561775
for fidx in 1:ndefuse
17571776
du = fielddefuse[fidx]
17581777
isempty(du.uses) && continue
1759-
push!(du.defs, newidx)
1778+
push!(du.defs, defidx)
17601779
ldu = compute_live_ins(ir.cfg, du)
17611780
if isempty(ldu.live_in_bbs)
17621781
phiblocks = Int[]
@@ -1769,7 +1788,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
17691788
for i = 1:length(du.uses)
17701789
use = du.uses[i]
17711790
if use.kind === :isdefined
1772-
if has_safe_def(ir, get!(lazydomtree), allblocks, du, newidx, use.idx)
1791+
if has_safe_def(ir, get!(lazydomtree), allblocks, du, defidx, use.idx)
17731792
ir[SSAValue(use.idx)][:stmt] = true
17741793
else
17751794
all_eliminated = false
@@ -1782,7 +1801,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
17821801
continue
17831802
end
17841803
end
1785-
has_safe_def(ir, get!(lazydomtree), allblocks, du, newidx, use.idx) || @goto skip
1804+
has_safe_def(ir, get!(lazydomtree), allblocks, du, defidx, use.idx) || @goto skip
17861805
end
17871806
else # always have some definition at the allocation site
17881807
for i = 1:length(du.uses)
@@ -1849,19 +1868,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
18491868
# all "usages" (i.e. `getfield` and `isdefined` calls) are eliminated,
18501869
# now eliminate "definitions" (i.e. `setfield!`) calls
18511870
# (NOTE the allocation itself will be eliminated by DCE pass later)
1852-
for idx in du.defs
1853-
idx == newidx && continue # this is allocation
1871+
for didx in du.defs
1872+
didx == defidx && continue # this is allocation
18541873
# verify this statement won't throw, otherwise it can't be eliminated safely
1855-
ssa = SSAValue(idx)
1856-
if is_nothrow(ir, ssa)
1857-
ir[ssa][:stmt] = nothing
1874+
setfield_ssa = SSAValue(didx)
1875+
if is_nothrow(ir, setfield_ssa)
1876+
ir[setfield_ssa][:stmt] = nothing
18581877
else
18591878
# We can't eliminate this statement, because it might still
18601879
# throw an error, but we can mark it as effect-free since we
18611880
# know we have removed all uses of the mutable allocation.
18621881
# As a result, if we ever do prove nothrow, we can delete
18631882
# this statement then.
1864-
add_flag!(ir[ssa], IR_FLAG_EFFECT_FREE)
1883+
add_flag!(ir[setfield_ssa], IR_FLAG_EFFECT_FREE)
18651884
end
18661885
end
18671886
end
@@ -1870,7 +1889,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
18701889
# this means all ccall preserves have been replaced with forwarded loads
18711890
# so we can potentially eliminate the allocation, otherwise we must preserve
18721891
# the whole allocation.
1873-
push!(intermediaries, newidx)
1892+
push!(intermediaries, defidx)
18741893
end
18751894
# Insert the new preserves
18761895
for (useidx, new_preserves) in preserve_uses

0 commit comments

Comments
 (0)