Skip to content

Commit b152846

Browse files
authored
Merge pull request #1465 from Pangoraw/unreachable_block
Handle unreachable blocks in the adjoint CFG
2 parents cf7f7d0 + bcf996a commit b152846

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
3838
ForwardDiff = "0.10"
3939
GPUArrays = "8.4.2, 9"
4040
GPUArraysCore = "0.1.1"
41-
IRTools = "0.4.4"
41+
IRTools = "0.4.11"
4242
LogExpFunctions = "0.3.1"
4343
MacroTools = "0.5"
4444
NaNMath = "0.3, 1"

src/compiler/reverse.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,6 @@ Variable(a::Alpha) = Variable(a.id)
244244
sig(b::IRTools.Block) = unique([arg for br in branches(b) for arg in br.args if arg isa Variable])
245245
sig(pr::Primal) = Dict(b.id => sig(b) for b in blocks(pr.ir))
246246

247-
# TODO unreachables?
248247
function adjointcfg(pr::Primal)
249248
ir = empty(pr.ir)
250249
return!(ir, nothing)
@@ -257,7 +256,9 @@ function adjointcfg(pr::Primal)
257256
push!(rb, xcall(Base, :(!==), alpha(pr.branches[b.id]), BranchNumber(i)))
258257
branch!(rb, preds[i].id, unless = cond)
259258
end
260-
if !isempty(branches(b)) && branches(b)[end] == IRTools.unreachable
259+
if isempty(preds) || (!isempty(branches(b)) && branches(b)[end] == IRTools.unreachable)
260+
# If `b` is unreachable, then no context produced by the primal should end up branching to `rb`
261+
push!(rb, xcall(Core, :throw, "unreachable")) # `throw` is necessary for inference not to hit the `unreachable`
261262
branch!(rb, 0)
262263
end
263264
end

test/compiler.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,3 +225,23 @@ end
225225

226226
# issue 897
227227
@test gradient(x -> sum(norm, collect(eachcol(x))), ones(3, 400))[1] fill(0.5773502691896258, 3, 400)
228+
229+
# issue 1118 & 1380
230+
function f_1380(x)
231+
if rand(Bool)
232+
return x
233+
else
234+
return 2x
235+
end
236+
237+
# unreachable
238+
return nothing
239+
end
240+
241+
@testset "unreachable block" begin
242+
y, back = Zygote.pullback(f_1380, 1.)
243+
# There should not be a compiler error
244+
local g
245+
@test_nowarn g = back(1.)
246+
@test only(g) (1., 2.)
247+
end

0 commit comments

Comments
 (0)