Skip to content

Commit ea2e9f5

Browse files
authored
Ensure that :boundscheck is properly handled (#177)
My previous fix was incorrect because `:boundscheck` returns a `Bool`. Let's just force it to `true` for now, which implicitly disables `@inbounds`, which is probably a good idea for us since we haven't gone through the effort of ensuring that all indexing is still valid. Add tests for this so that we catch this breakage faster next time.
1 parent 9e04378 commit ea2e9f5

File tree

3 files changed

+30
-2
lines changed

3 files changed

+30
-2
lines changed

src/codegen/forward.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,15 @@ function fwd_transform!(ci, mi, nargs, N)
5454
elseif isexpr(stmt, :foreigncall)
5555
return Expr(:call, error, "Attempted to AD a foreigncall. Missing rule?")
5656
elseif isexpr(stmt, :meta) || isexpr(stmt, :inbounds) || isexpr(stmt, :loopinfo) ||
57-
isexpr(stmt, :boundscheck) || isexpr(stmt, :code_coverage_effect)
57+
isexpr(stmt, :code_coverage_effect)
5858
# Can't trust that meta annotations are still valid in the AD'd
5959
# version.
6060
return nothing
61+
62+
# Always disable `@inbounds`, as we don't actually know if the AD'd
63+
# code is truly `@inbounds` or not.
64+
elseif isexpr(stmt, :boundscheck)
65+
return ZeroBundle{N}(true)
6166
else
6267
# Fallback case, for literals.
6368
# If it is an Expr, then it is not a literal

src/codegen/reverse.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,8 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
321321
# We drop gradients for globals and static parameters
322322
elseif isexpr(stmt, :inbounds)
323323
# Nothing to do
324+
elseif isexpr(stmt, :boundscheck)
325+
# TODO: do something here
324326
elseif isa(stmt, PhiNode)
325327
Δ = do_accum(SSAValue(i))
326328
@assert length(ir.cfg.blocks[bb].preds) >= 1

test/forward.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,4 +69,25 @@ end
6969
@test primal_calls[] == 1
7070
end
7171

72-
end
72+
@testset "indexing" begin
73+
# Test to make sure that `:boundscheck` and such are properly handled
74+
function foo(x)
75+
t = (x, x)
76+
return t[1] + 1
77+
end
78+
79+
let var"'" = Diffractor.PrimeDerivativeFwd
80+
@test foo'(1.0) == 1.0
81+
end
82+
83+
# Test that `@inbounds` is ignored by Diffractor
84+
function foo_errors(x)
85+
t = (x, x)
86+
@inbounds return t[3] + 1
87+
end
88+
let var"'" = Diffractor.PrimeDerivativeFwd
89+
@test_throws BoundsError foo_errors'(1.0) == 1.0
90+
end
91+
end
92+
93+
end

0 commit comments

Comments
 (0)