Skip to content

Commit b05424b

Browse files
committed
wip: implement a better statement selection logic
Specifically, this commit aims to review the implementation of `add_control_flow!` and improves its accuracy. Ideally, it should pass JET's existing test cases as well as the newly added ones, including the test cases from JuliaDebug/LoweredCodeUtils.jl#99. The goal is to share the same high-precision CFG selection logic between LoweredCodeUtils and JET. The current implementation is based on [this paper](https://www.cse.msu.edu/~cse870/Public/Homework/SS2003/HW5/p439-weiser.pdf), and it has been modified to use an algorithm that checks for liveness in the reachable blocks up to the nearest common postdominator of the successors of a conditional terminator.
1 parent 2dfcd4e commit b05424b

File tree

2 files changed

+86
-41
lines changed

2 files changed

+86
-41
lines changed

src/toplevel/virtualprocess.jl

Lines changed: 52 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,63 +1177,55 @@ end
11771177
# and if there is an active successor and the terminator is not a fall-through, then request
11781178
# the concretization of that terminator. Additionally, for conditional terminators, a simple
11791179
# optimization using post-domination analysis is also performed.
1180-
function add_control_flow!(concretize::BitVector, src::CodeInfo, cfg::CFG, postdomtree)
1180+
function add_control_flow!(concretize::BitVector, src::CodeInfo, cfg::CFG, domtree, postdomtree)
11811181
local changed::Bool = false
11821182
function mark_concretize!(idx::Int)
11831183
if !concretize[idx]
1184-
concretize[idx] = true
1184+
changed |= concretize[idx] = true
11851185
return true
11861186
end
11871187
return false
11881188
end
1189-
nblocks = length(cfg.blocks)
1190-
for bbidx = 1:nblocks
1191-
bb = cfg.blocks[bbidx] # forward traversal
1189+
for bbidx = 1:length(cfg.blocks) # forward traversal
1190+
bb = cfg.blocks[bbidx]
11921191
nsuccs = length(bb.succs)
11931192
if nsuccs == 0
11941193
continue
11951194
elseif nsuccs == 1
1196-
terminator_idx = bb.stmts[end]
1197-
if src.code[terminator_idx] isa GotoNode
1198-
# If the destination of this `GotoNode` is not active, it's fine to ignore
1199-
# the control flow caused by this `GotoNode` and treat it as a fall-through.
1200-
# If the block that is fallen through to is active and has a dependency on
1201-
# this goto block, then the concretization of this goto block should already
1202-
# be requested (at some point of the higher concretization convergence cycle
1203-
# of `select_dependencies`), and thus, this `GotoNode` will be concretized.
1204-
if any(@view concretize[cfg.blocks[only(bb.succs)].stmts])
1205-
changed |= mark_concretize!(terminator_idx)
1195+
termidx = bb.stmts[end]
1196+
if src.code[termidx] isa GotoNode
1197+
succ = only(bb.succs)
1198+
if any(@view concretize[cfg.blocks[succ].stmts])
1199+
dominator = CC.nearest_common_dominator(domtree, bbidx, succ)
1200+
postdominator = CC.nearest_common_dominator(domtree, bbidx, succ)
1201+
for blk in reachable_blocks(cfg, dominator, postdominator)
1202+
if blk == dominator || blk == postdominator
1203+
continue
1204+
end
1205+
if any(@view concretize[cfg.blocks[blk].stmts])
1206+
mark_concretize!(termidx)
1207+
break
1208+
end
1209+
end
12061210
end
12071211
end
1212+
continue # otherwise we can just fall-through
12081213
elseif nsuccs == 2
1209-
terminator_idx = bb.stmts[end]
1210-
@assert is_conditional_terminator(src.code[terminator_idx]) "invalid IR"
1214+
termidx = bb.stmts[end]
1215+
@assert is_conditional_terminator(src.code[termidx]) "invalid IR"
12111216
succ1, succ2 = bb.succs
1212-
succ1_req = any(@view concretize[cfg.blocks[succ1].stmts])
1213-
succ2_req = any(@view concretize[cfg.blocks[succ2].stmts])
1214-
if succ1_req
1215-
if succ2_req
1216-
changed |= mark_concretize!(terminator_idx)
1217-
else
1218-
active_bb, inactive_bb = succ1, succ2
1219-
@goto asymmetric_case
1217+
postdominator = CC.nearest_common_dominator(postdomtree, succ1, succ2)
1218+
for blk in (reachable_blocks(cfg, succ1, postdominator)
1219+
reachable_blocks(cfg, succ2, postdominator))
1220+
if blk == postdominator
1221+
continue
12201222
end
1221-
elseif succ2_req
1222-
active_bb, inactive_bb = succ2, succ1
1223-
@label asymmetric_case
1224-
# We can ignore the control flow of this conditional terminator and treat
1225-
# it as a fall-through if only one of its successors is active and the
1226-
# active block post-dominates the inactive one, since the post-domination
1227-
# ensures that the active basic block will be reached regardless of the
1228-
# control flow.
1229-
if CC.postdominates(postdomtree, active_bb, inactive_bb)
1230-
# fall through this block
1231-
else
1232-
changed |= mark_concretize!(terminator_idx)
1223+
if any(@view concretize[cfg.blocks[blk].stmts])
1224+
mark_concretize!(termidx)
1225+
break
12331226
end
1234-
else
1235-
# both successors are inactive, just fall through this block
12361227
end
1228+
# we can just fall-through to the post dominator block (by ignoring all statements between)
12371229
end
12381230
end
12391231
return changed
@@ -1242,6 +1234,25 @@ end
12421234
is_conditional_terminator(@nospecialize stmt) = stmt isa GotoIfNot ||
12431235
(@static @isdefined(EnterNode) ? stmt isa EnterNode : isexpr(stmt, :enter))
12441236

1237+
function reachable_blocks(cfg::CFG, from_bb::Int, to_bb::Int)
1238+
worklist = Int[from_bb]
1239+
visited = BitSet(from_bb)
1240+
if to_bb == from_bb
1241+
return visited
1242+
end
1243+
push!(visited, to_bb)
1244+
function visit!(bb::Int)
1245+
if bb visited
1246+
push!(visited, bb)
1247+
push!(worklist, bb)
1248+
end
1249+
end
1250+
while !isempty(worklist)
1251+
foreach(visit!, cfg.blocks[pop!(worklist)].succs)
1252+
end
1253+
return visited
1254+
end
1255+
12451256
function add_required_inplace!(concretize::BitVector, src::CodeInfo, edges, cl)
12461257
changed = false
12471258
for i = 1:length(src.code)
@@ -1275,6 +1286,7 @@ end
12751286
function select_dependencies!(concretize::BitVector, src::CodeInfo, edges, cl)
12761287
typedefs = LoweredCodeUtils.find_typedefs(src)
12771288
cfg = CC.compute_basic_blocks(src.code)
1289+
domtree = CC.construct_domtree(cfg.blocks)
12781290
postdomtree = CC.construct_postdomtree(cfg.blocks)
12791291

12801292
while true
@@ -1292,7 +1304,7 @@ function select_dependencies!(concretize::BitVector, src::CodeInfo, edges, cl)
12921304

12931305
# mark necessary control flows,
12941306
# and propagate the definition requirements by tracking SSA precedessors
1295-
changed |= add_control_flow!(concretize, src, cfg, postdomtree)
1307+
changed |= add_control_flow!(concretize, src, cfg, domtree, postdomtree)
12961308
changed |= add_ssa_preds!(concretize, src, edges, ())
12971309

12981310
changed || break

test/toplevel/test_virtualprocess.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1748,7 +1748,7 @@ end
17481748
found_write = true
17491749
@test !slice[i]
17501750
elseif (JET.isexpr(stmt, :call) && (arg1 = stmt.args[1]; arg1 isa Core.SSAValue) &&
1751-
src.code[arg1.id] === :write)
1751+
src.code[arg1.id] === :write)
17521752
found_write = true
17531753
@test !slice[i]
17541754
end
@@ -1778,6 +1778,39 @@ end
17781778
@test isempty(s)
17791779
end
17801780

1781+
# more complex case
1782+
let src = @src begin
1783+
x2 = 5
1784+
a2 = 1
1785+
@eval global geta2() = $a2 # concretization is forced
1786+
end
1787+
slice = JET.select_statements(@__MODULE__, src)
1788+
1789+
found_a2 = found_a2_get_binding_type = found_x2 = found_x2_get_binding_type = false
1790+
for (i, stmt) in enumerate(src.code)
1791+
if JET.isexpr(stmt, :(=))
1792+
lhs, rhs = stmt.args
1793+
if lhs isa GlobalRef
1794+
lhs = lhs.name
1795+
end
1796+
if lhs === :a2
1797+
found_a2 = true
1798+
@test slice[i]
1799+
elseif lhs === :x2
1800+
found_x2 = true
1801+
@test !slice[i] # this is easy to meet
1802+
end
1803+
elseif JET.@capture(stmt, $(GlobalRef(Core, :get_binding_type))(_, :a2))
1804+
found_a2_get_binding_type = true
1805+
@test slice[i]
1806+
elseif JET.@capture(stmt, $(GlobalRef(Core, :get_binding_type))(_, :x2))
1807+
found_x2_get_binding_type = true
1808+
@test !slice[i] # this is difficult to meet
1809+
end
1810+
end
1811+
@test found_a2; @test found_a2_get_binding_type; @test found_x2; @test found_x2_get_binding_type
1812+
end
1813+
17811814
@testset "captured variables" begin
17821815
let (vmod, res) = @analyze_toplevel2 begin
17831816
begin

0 commit comments

Comments
 (0)