Skip to content

Commit d07a245

Browse files
authored
Elide bounds checks when kernels contain manual ones. (#2621)
1 parent 6ef1a3d commit d07a245

File tree

3 files changed

+32
-4
lines changed

3 files changed

+32
-4
lines changed

src/device/array.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ Base.size(g::CuDeviceArray) = g.dims
4545
Base.sizeof(x::CuDeviceArray) = Base.elsize(x) * length(x)
4646

4747
# we store the array length too; computing prod(size) is expensive
48+
Base.size(g::CuDeviceArray{<:Any,1}) = (g.len,)
4849
Base.length(g::CuDeviceArray) = g.len
4950

5051
Base.pointer(x::CuDeviceArray{T,<:Any,A}) where {T,A} = Base.unsafe_convert(LLVMPtr{T,A}, x)
@@ -78,7 +79,11 @@ Base.unsafe_convert(::Type{LLVMPtr{T,A}}, x::CuDeviceArray{T,<:Any,A}) where {T,
7879
end
7980

8081
@device_function @inline function arrayref(A::CuDeviceArray{T}, index::Integer) where {T}
81-
@boundscheck checkbounds(A, index)
82+
# simplified bounds check to avoid the OneTo construction, which calls `max`
83+
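# and breaks elimination of redundant bounds checks in the generated code.
# NOTE(review): the simplified check only tests the upper bound (`index <= length(A)`);
# an index below 1 is not rejected by it. Confirm this is intended.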
84+
#@boundscheck checkbounds(A, index)
85+
@boundscheck index <= length(A) || Base.throw_boundserror(A, index)
86+
8287
if Base.isbitsunion(T)
8388
arrayref_union(A, index)
8489
else
@@ -120,7 +125,10 @@ end
120125
end
121126

122127
@device_function @inline function arrayset(A::CuDeviceArray{T}, x::T, index::Integer) where {T}
123-
@boundscheck checkbounds(A, index)
128+
# simplified bounds check (see `arrayref`)
129+
#@boundscheck checkbounds(A, index)
130+
@boundscheck index <= length(A) || Base.throw_boundserror(A, index)
131+
124132
if Base.isbitsunion(T)
125133
arrayset_union(A, x, index)
126134
else
@@ -151,7 +159,10 @@ end
151159
end
152160

153161
@device_function @inline function const_arrayref(A::CuDeviceArray{T}, index::Integer) where {T}
154-
@boundscheck checkbounds(A, index)
162+
# simplified bounds check (see `arrayref`)
163+
#@boundscheck checkbounds(A, index)
164+
@boundscheck index <= length(A) || Base.throw_boundserror(A, index)
165+
155166
align = alignment(A)
156167
unsafe_cached_load(pointer(A), index, Val(align))
157168
end

test/base/exceptions.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ let (proc, out, err) = julia_exec(`-g2 -e $script`)
5252
@test count(device_error_re, out) == 1
5353
@test count("BoundsError", out) == 1
5454
@test count("Out-of-bounds array access", out) == 1
55-
@test occursin("] checkbounds at $(joinpath(".", "abstractarray.jl"))", out)
5655
@test occursin("] kernel at $(joinpath(".", "none"))", out)
5756
end
5857

test/core/device/array.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,24 @@ end
7979
@test !occursin("jl_invoke", ir)
8080
CUDA.code_ptx(devnull, kernel, tt)
8181
end
82+
83+
# test that we don't do needless bounds checking when the kernel already does it
84+
# (enabled by the fact that we store `len` next to `dims`)
85+
let
86+
function kernel(A)
87+
idx = threadIdx().x
88+
if idx <= length(A)
89+
# we did our own bounds checking, so no check should be left!
90+
A[idx] = 1
91+
end
92+
return
93+
end
94+
95+
for N in 1:3
96+
ir = sprint(io->CUDA.code_llvm(io, kernel, Tuple{CuDeviceArray{Int,N,AS.Global}}))
97+
@test !occursin("boundserror", ir)
98+
end
99+
end
82100
end
83101

84102
@testset "views" begin

0 commit comments

Comments
 (0)