Skip to content

Commit 7e6acc9

Browse files
authored
Elide bounds checks when kernels contains manual ones. (#486)
1 parent 517c254 commit 7e6acc9

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

src/device/array.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ Base.size(g::oneDeviceArray) = g.dims
4848
Base.sizeof(x::oneDeviceArray) = Base.elsize(x) * length(x)
4949

5050
# we store the array length too; computing prod(size) is expensive
51+
Base.size(g::oneDeviceArray{<:Any, 1}) = (g.len,)
5152
Base.length(g::oneDeviceArray) = g.len
5253

5354
Base.pointer(x::oneDeviceArray{T,<:Any,A}) where {T,A} = Base.unsafe_convert(LLVMPtr{T,A}, x)
@@ -81,7 +82,11 @@ Base.unsafe_convert(::Type{LLVMPtr{T,A}}, x::oneDeviceArray{T,<:Any,A}) where {T
8182
end
8283

8384
@device_function @inline function arrayref(A::oneDeviceArray{T}, index::Integer) where {T}
84-
@boundscheck checkbounds(A, index)
85+
# simplified bounds check to avoid the OneTo construction, which calls `max`
86+
# and breaks elimination of redundant bounds checks in the generated code.
87+
#@boundscheck checkbounds(A, index)
88+
@boundscheck index <= length(A) || Base.throw_boundserror(A, index)
89+
8590
if isbitstype(T)
8691
arrayref_bits(A, index)
8792
else #if isbitsunion(T)
@@ -123,7 +128,10 @@ end
123128
end
124129

125130
@device_function @inline function arrayset(A::oneDeviceArray{T}, x::T, index::Integer) where {T}
126-
@boundscheck checkbounds(A, index)
131+
# simplified bounds check (see `arrayref`)
132+
#@boundscheck checkbounds(A, index)
133+
@boundscheck index <= length(A) || Base.throw_boundserror(A, index)
134+
127135
if isbitstype(T)
128136
arrayset_bits(A, x, index)
129137
else #if isbitsunion(T)
@@ -154,7 +162,10 @@ end
154162
end
155163

156164
@device_function @inline function const_arrayref(A::oneDeviceArray{T}, index::Integer) where {T}
157-
@boundscheck checkbounds(A, index)
165+
# simplified bounds check (see `arrayset`)
166+
#@boundscheck checkbounds(A, index)
167+
@boundscheck index <= length(A) || Base.throw_boundserror(A, index)
168+
158169
align = alignment(A)
159170
unsafe_cached_load(pointer(A), index, Val(align))
160171
end

0 commit comments

Comments
 (0)