@@ -45,6 +45,7 @@ Base.size(g::CuDeviceArray) = g.dims
4545Base. sizeof (x:: CuDeviceArray ) = Base. elsize (x) * length (x)
4646
4747# we store the array length too; computing prod(size) is expensive
48+ Base. size (g:: CuDeviceArray{<:Any,1} ) = (g. len,)
4849Base. length (g:: CuDeviceArray ) = g. len
4950
5051Base. pointer (x:: CuDeviceArray{T,<:Any,A} ) where {T,A} = Base. unsafe_convert (LLVMPtr{T,A}, x)
@@ -78,7 +79,11 @@ Base.unsafe_convert(::Type{LLVMPtr{T,A}}, x::CuDeviceArray{T,<:Any,A}) where {T,
7879end
7980
8081@device_function @inline function arrayref (A:: CuDeviceArray{T} , index:: Integer ) where {T}
81- @boundscheck checkbounds (A, index)
82+ # simplified bounds check to avoid the OneTo construction, which calls `max`
83+ # and breaks elimination of redundant bounds checks in the generated code.
84+ # @boundscheck checkbounds(A, index)
85+ @boundscheck index <= length (A) || Base. throw_boundserror (A, index)
86+
8287 if Base. isbitsunion (T)
8388 arrayref_union (A, index)
8489 else
120125end
121126
122127@device_function @inline function arrayset (A:: CuDeviceArray{T} , x:: T , index:: Integer ) where {T}
123- @boundscheck checkbounds (A, index)
128+ # simplified bounds check (see `arrayref`)
129+ # @boundscheck checkbounds(A, index)
130+ @boundscheck index <= length (A) || Base. throw_boundserror (A, index)
131+
124132 if Base. isbitsunion (T)
125133 arrayset_union (A, x, index)
126134 else
151159end
152160
153161@device_function @inline function const_arrayref (A:: CuDeviceArray{T} , index:: Integer ) where {T}
154- @boundscheck checkbounds (A, index)
162+ # simplified bounds check (see `arrayset`)
163+ # @boundscheck checkbounds(A, index)
164+ @boundscheck index <= length (A) || Base. throw_boundserror (A, index)
165+
155166 align = alignment (A)
156167 unsafe_cached_load (pointer (A), index, Val (align))
157168end
0 commit comments