@@ -181,7 +181,8 @@ Base.copy(a::JLArray{T,N}) where {T,N} = JLArray{T,N}(copy(a.data), size(a))
181
181
# # derived types
182
182
183
183
export DenseJLArray, DenseJLVector, DenseJLMatrix, DenseJLVecOrMat,
184
- StridedJLArray, StridedJLVector, StridedJLMatrix, StridedJLVecOrMat
184
+ StridedJLArray, StridedJLVector, StridedJLMatrix, StridedJLVecOrMat,
185
+ AnyJLArray, AnyJLVector, AnyJLMatrix, AnyJLVecOrMat
185
186
186
187
ContiguousSubJLArray{T,N,A<: JLArray } = Base. FastContiguousSubArray{T,N,A}
187
188
@@ -208,6 +209,12 @@ StridedJLVector{T} = StridedJLArray{T,1}
208
209
StridedJLMatrix{T} = StridedJLArray{T,2 }
209
210
StridedJLVecOrMat{T} = Union{StridedJLVector{T}, StridedJLMatrix{T}}
210
211
212
+ # anything that's (secretly) backed by a JLArray
213
+ AnyJLArray{T,N} = Union{JLArray{T,N}, WrappedArray{T,N,JLArray,JLArray{T,N}}}
214
+ AnyJLVector{T} = AnyJLArray{T,1 }
215
+ AnyJLMatrix{T} = AnyJLArray{T,2 }
216
+ AnyJLVecOrMat{T} = Union{AnyJLVector{T}, AnyJLMatrix{T}}
217
+
211
218
212
219
# # array interface
213
220
@@ -259,7 +266,7 @@ struct JLArrayStyle{N} <: AbstractGPUArrayStyle{N} end
259
266
JLArrayStyle (:: Val{N} ) where N = JLArrayStyle {N} ()
260
267
JLArrayStyle {M} (:: Val{N} ) where {N,M} = JLArrayStyle {N} ()
261
268
262
- BroadcastStyle (:: Type{JLArray {T,N}} ) where {T,N} = JLArrayStyle {N} ()
269
+ BroadcastStyle (:: Type{<:AnyJLArray {T,N}} ) where {T,N} = JLArrayStyle {N} ()
263
270
264
271
# Allocating the output container
265
272
Base. similar (bc:: Broadcasted{JLArrayStyle{N}} , :: Type{T} ) where {N,T} =
@@ -355,20 +362,20 @@ end
355
362
using Random
356
363
357
364
# JLArray only supports generating random numbers with the GPUArrays RNG
358
- Random. rand! (A:: JLArray ) = Random. rand! (GPUArrays. default_rng (JLArray), A)
359
- Random. randn! (A:: JLArray ) = Random. randn! (GPUArrays. default_rng (JLArray), A)
365
+ Random. rand! (A:: AnyJLArray ) = Random. rand! (GPUArrays. default_rng (JLArray), A)
366
+ Random. randn! (A:: AnyJLArray ) = Random. randn! (GPUArrays. default_rng (JLArray), A)
360
367
361
368
362
369
# # GPUArrays interfaces
363
370
364
- GPUArrays. device (x:: JLArray ) = JLDevice ()
371
+ GPUArrays. device (x:: AnyJLArray ) = JLDevice ()
365
372
366
- GPUArrays. backend (:: Type{<:JLArray } ) = JLBackend ()
373
+ GPUArrays. backend (:: Type{<:AnyJLArray } ) = JLBackend ()
367
374
368
375
Adapt. adapt_storage (:: Adaptor , x:: JLArray{T,N} ) where {T,N} =
369
376
JLDeviceArray {T,N} (x. data, x. dims)
370
377
371
- function GPUArrays. mapreducedim! (f, op, R:: JLArray , A:: Union{AbstractArray,Broadcast.Broadcasted} ;
378
+ function GPUArrays. mapreducedim! (f, op, R:: AnyJLArray , A:: Union{AbstractArray,Broadcast.Broadcasted} ;
372
379
init= nothing )
373
380
if init != = nothing
374
381
fill! (R, init)
0 commit comments