Skip to content

Commit 976a6f2

Browse files
committed
Use Any alias for the JLArray reference implementation.
1 parent 922e538 commit 976a6f2

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

src/reference.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ Base.copy(a::JLArray{T,N}) where {T,N} = JLArray{T,N}(copy(a.data), size(a))
181181
## derived types
182182

183183
export DenseJLArray, DenseJLVector, DenseJLMatrix, DenseJLVecOrMat,
184-
StridedJLArray, StridedJLVector, StridedJLMatrix, StridedJLVecOrMat
184+
StridedJLArray, StridedJLVector, StridedJLMatrix, StridedJLVecOrMat,
185+
AnyJLArray, AnyJLVector, AnyJLMatrix, AnyJLVecOrMat
185186

186187
ContiguousSubJLArray{T,N,A<:JLArray} = Base.FastContiguousSubArray{T,N,A}
187188

@@ -208,6 +209,12 @@ StridedJLVector{T} = StridedJLArray{T,1}
208209
StridedJLMatrix{T} = StridedJLArray{T,2}
209210
StridedJLVecOrMat{T} = Union{StridedJLVector{T}, StridedJLMatrix{T}}
210211

212+
# anything that's (secretly) backed by a JLArray
213+
AnyJLArray{T,N} = Union{JLArray{T,N}, WrappedArray{T,N,JLArray,JLArray{T,N}}}
214+
AnyJLVector{T} = AnyJLArray{T,1}
215+
AnyJLMatrix{T} = AnyJLArray{T,2}
216+
AnyJLVecOrMat{T} = Union{AnyJLVector{T}, AnyJLMatrix{T}}
217+
211218

212219
## array interface
213220

@@ -259,7 +266,7 @@ struct JLArrayStyle{N} <: AbstractGPUArrayStyle{N} end
259266
JLArrayStyle(::Val{N}) where N = JLArrayStyle{N}()
260267
JLArrayStyle{M}(::Val{N}) where {N,M} = JLArrayStyle{N}()
261268

262-
BroadcastStyle(::Type{JLArray{T,N}}) where {T,N} = JLArrayStyle{N}()
269+
BroadcastStyle(::Type{<:AnyJLArray{T,N}}) where {T,N} = JLArrayStyle{N}()
263270

264271
# Allocating the output container
265272
Base.similar(bc::Broadcasted{JLArrayStyle{N}}, ::Type{T}) where {N,T} =
@@ -355,20 +362,20 @@ end
355362
using Random
356363

357364
# JLArray only supports generating random numbers with the GPUArrays RNG
358-
Random.rand!(A::JLArray) = Random.rand!(GPUArrays.default_rng(JLArray), A)
359-
Random.randn!(A::JLArray) = Random.randn!(GPUArrays.default_rng(JLArray), A)
365+
Random.rand!(A::AnyJLArray) = Random.rand!(GPUArrays.default_rng(JLArray), A)
366+
Random.randn!(A::AnyJLArray) = Random.randn!(GPUArrays.default_rng(JLArray), A)
360367

361368

362369
## GPUArrays interfaces
363370

364-
GPUArrays.device(x::JLArray) = JLDevice()
371+
GPUArrays.device(x::AnyJLArray) = JLDevice()
365372

366-
GPUArrays.backend(::Type{<:JLArray}) = JLBackend()
373+
GPUArrays.backend(::Type{<:AnyJLArray}) = JLBackend()
367374

368375
Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
369376
JLDeviceArray{T,N}(x.data, x.dims)
370377

371-
function GPUArrays.mapreducedim!(f, op, R::JLArray, A::Union{AbstractArray,Broadcast.Broadcasted};
378+
function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Broadcast.Broadcasted};
372379
init=nothing)
373380
if init !== nothing
374381
fill!(R, init)

0 commit comments

Comments
 (0)