Skip to content

Commit a7b3f75

Browse files
committed
Fix issues with critical edges
1 parent 02927c3 commit a7b3f75

File tree

3 files changed

+85
-32
lines changed

3 files changed

+85
-32
lines changed

src/stage1/compiler_utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,6 @@ function Base.iterate(it::Iterators.Reverse{BBIdxIter},
4444
end
4545
return (bb, idx - 1), (bb, idx - 1)
4646
end
47+
48+
Base.lastindex(x::Core.Compiler.InstructionStream) =
49+
Core.Compiler.length(x)

src/stage1/recurse.jl

Lines changed: 57 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -277,11 +277,12 @@ function transform!(ci, meth, nargs, sparams, N)
277277

278278
if length(cfg.blocks) != 1
279279
ϕ = PhiNode()
280+
280281
bb_start = length(ir.stmts)+1
281282
push!(ir, NewInstruction(ϕ))
282283
push!(ir, NewInstruction(ReturnNode(SSAValue(length(ir.stmts)))))
283284
push!(ir.cfg, BasicBlock(StmtRange(bb_start, length(ir.stmts))))
284-
new_bb_idx = length(cfg.blocks)
285+
new_bb_idx = length(ir.cfg.blocks)
285286

286287
for (bb, i) in bbidxiter(ir)
287288
bb == new_bb_idx && break
@@ -301,6 +302,27 @@ function transform!(ci, meth, nargs, sparams, N)
301302

302303
ir = compact!(ir)
303304
ir = split_critical_edges!(ir)
305+
306+
# If that resulted in the return not being the last block, fix that now.
307+
# We keep things simple this way, such that the basic blocks in the
308+
# forward and reverse are simply inverses of each other (i.e. the
309+
# exist block needs to be last, since the entry block needs to be first
310+
# in the reverse pass).
311+
312+
if !isa(ir.stmts[end][:inst], ReturnNode)
313+
new_bb_idx = length(ir.cfg.blocks)+1
314+
for (bb, i) in bbidxiter(ir)
315+
stmt = ir.stmts[i][:inst]
316+
if isa(stmt, ReturnNode)
317+
ir[i] = NewInstruction(GotoNode(new_bb_idx))
318+
push!(ir, NewInstruction(stmt))
319+
push!(ir.cfg, BasicBlock(StmtRange(length(ir.stmts), length(ir.stmts))))
320+
cfg_insert_edge!(ir.cfg, bb, new_bb_idx)
321+
break
322+
end
323+
end
324+
end
325+
304326
cfg = ir.cfg
305327

306328
# Now add a special control flow marker to every basic block
@@ -499,6 +521,7 @@ function transform!(ci, meth, nargs, sparams, N)
499521

500522
if isa(stmt, Core.ReturnNode)
501523
accum!(stmt.val, Argument(2))
524+
current_env = nothing
502525
elseif isexpr(stmt, :call)
503526
Δ = do_accum(SSAValue(i))
504527
callee = retrieve_ctx_obj(current_env, i)
@@ -784,38 +807,41 @@ function transform!(ci, meth, nargs, sparams, N)
784807
end
785808

786809
succs = cfg.blocks[active_bb].succs
787-
if old_idx == last(orig_bb_ranges[active_bb]) && length(succs) != 0
788-
override = false
789-
if has_terminator[active_bb]
790-
terminator = compact[idx]
791-
compact[idx] = nothing
792-
override = true
793-
end
794-
function terminator_insert_node!(node)
795-
if override
796-
compact[idx] = node.stmt
797-
override = false
798-
return SSAValue(idx)
799-
else
800-
return insert_node_here!(compact, node, true)
810+
811+
if old_idx == last(orig_bb_ranges[active_bb])
812+
if length(succs) != 0
813+
override = false
814+
if has_terminator[active_bb]
815+
terminator = compact[idx]
816+
compact[idx] = nothing
817+
override = true
801818
end
802-
end
803-
tup = terminator_insert_node!(
804-
effect_free(NewInstruction(Expr(:call, tuple, rev[orig_bb_ranges[active_bb]]...), Any, Int32(0))))
805-
for succ in succs
806-
preds = cfg.blocks[succ].preds
807-
if length(preds) == 1
808-
val = tup
809-
else
810-
selector = findfirst(==(active_bb), preds)
811-
val = insert_node_here!(compact, effect_free(NewInstruction(Expr(:call, tuple, selector, tup), Any, Int32(0))), true)
819+
function terminator_insert_node!(node)
820+
if override
821+
compact[idx] = node.stmt
822+
override = false
823+
return SSAValue(idx)
824+
else
825+
return insert_node_here!(compact, node, true)
826+
end
827+
end
828+
tup = terminator_insert_node!(
829+
effect_free(NewInstruction(Expr(:call, tuple, rev[orig_bb_ranges[active_bb]]...), Any, Int32(0))))
830+
for succ in succs
831+
preds = cfg.blocks[succ].preds
832+
if length(preds) == 1
833+
val = tup
834+
else
835+
selector = findfirst(==(active_bb), preds)
836+
val = insert_node_here!(compact, effect_free(NewInstruction(Expr(:call, tuple, selector, tup), Any, Int32(0))), true)
837+
end
838+
pn = phi_nodes[succ]
839+
push!(pn.edges, active_bb)
840+
push!(pn.values, val)
841+
end
842+
if has_terminator[active_bb]
843+
insert_node_here!(compact, NewInstruction(terminator, Any, Int32(0)), true)
812844
end
813-
pn = phi_nodes[succ]
814-
push!(pn.edges, active_bb)
815-
push!(pn.values, val)
816-
end
817-
if has_terminator[active_bb]
818-
insert_node_here!(compact, NewInstruction(terminator, Any, Int32(0)), true)
819845
end
820846
active_bb += 1
821847
end

test/runtests.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,32 @@ function f_broadcast(a)
155155
end
156156
@test fwd(f_broadcast)(1.0) == bwd(f_broadcast)(1.0)
157157

158+
# Make sure that there's no infinite recursion in kwarg calls
158159
g_kw(;x=1.0) = sin(x)
159160
f_kw(x) = g_kw(;x)
160-
bwd(f_kw)
161+
@test bwd(f_kw)(1.0) == bwd(sin)(1.0)
162+
163+
function f_crit_edge(a, b, c, x)
164+
# A function with two critical edges. This used to trigger an issue where
165+
# Diffractor would fail to insert edges for the second split critical edge.
166+
y = 1x
167+
if a && b
168+
y = 2x
169+
end
170+
if b && c
171+
y = 3x
172+
end
173+
174+
if c
175+
y = 4y
176+
end
177+
178+
return y
179+
end
180+
@test bwd(x->f_crit_edge(false, false, false, x))(1.0) == 1.0
181+
@test bwd(x->f_crit_edge(true, true, false, x))(1.0) == 2.0
182+
@test bwd(x->f_crit_edge(false, true, true, x))(1.0) == 12.0
183+
@test bwd(x->f_crit_edge(false, false, true, x))(1.0) == 4.0
184+
161185

162186
include("pinn.jl")

0 commit comments

Comments
 (0)