@@ -23,9 +23,7 @@ backend(::Type{<:AbstractGPUDevice}) = error("Not implemented") # COV_EXCL_LINE
23
23
24
24
const WrappedGPUArray{T,N} = WrappedArray{T,N,AbstractGPUArray,AbstractGPUArray{T,N}}
25
25
26
- const AbstractOrWrappedGPUArray{T,N} =
27
- Union{AbstractGPUArray{T,N},
28
- WrappedArray{T,N,AbstractGPUArray,AbstractGPUArray{T,N}}}
26
+ const AnyGPUArray{T,N} = Union{AbstractGPUArray{T,N}, WrappedGPUArray{T,N}}
29
27
30
28
31
29
# input/output
@@ -51,21 +49,21 @@ convert_to_cpu(xs) = adapt(Array, xs)
51
49
# # showing
52
50
53
51
# display
54
- Base. print_array (io:: IO , X:: AbstractOrWrappedGPUArray ) =
52
+ Base. print_array (io:: IO , X:: AnyGPUArray ) =
55
53
Base. print_array (io, convert_to_cpu (X))
56
54
57
55
# show
58
- Base. _show_nonempty (io:: IO , X:: AbstractOrWrappedGPUArray , prefix:: String ) =
56
+ Base. _show_nonempty (io:: IO , X:: AnyGPUArray , prefix:: String ) =
59
57
Base. _show_nonempty (io, convert_to_cpu (X), prefix)
60
- Base. _show_empty (io:: IO , X:: AbstractOrWrappedGPUArray ) =
58
+ Base. _show_empty (io:: IO , X:: AnyGPUArray ) =
61
59
Base. _show_empty (io, convert_to_cpu (X))
62
- Base. show_vector (io:: IO , v:: AbstractOrWrappedGPUArray , args... ) =
60
+ Base. show_vector (io:: IO , v:: AnyGPUArray , args... ) =
63
61
Base. show_vector (io, convert_to_cpu (v), args... )
64
62
65
63
# # collect to CPU (discarding wrapper type)
66
64
67
65
collect_to_cpu (xs:: AbstractArray ) = collect (convert_to_cpu (xs))
68
- Base. collect (X:: AbstractOrWrappedGPUArray ) = collect_to_cpu (X)
66
+ Base. collect (X:: AnyGPUArray ) = collect_to_cpu (X)
69
67
70
68
71
69
# memory copying
@@ -75,9 +73,9 @@ Base.collect(X::AbstractOrWrappedGPUArray) = collect_to_cpu(X)
75
73
# expects the GPU array type to have linear `copyto!` methods (i.e. accepting an integer
76
74
# offset and length) from and to CPU arrays and between GPU arrays.
77
75
78
- for (D, S) in ((AbstractOrWrappedGPUArray , Array),
79
- (Array, AbstractOrWrappedGPUArray ),
80
- (AbstractOrWrappedGPUArray, AbstractOrWrappedGPUArray ))
76
+ for (D, S) in ((AnyGPUArray , Array),
77
+ (Array, AnyGPUArray ),
78
+ (AnyGPUArray, AnyGPUArray ))
81
79
@eval begin
82
80
function Base. copyto! (dest:: $D{<:Any, N} , rdest:: UnitRange ,
83
81
src:: $S{<:Any, N} , ssrc:: UnitRange ) where {N}
@@ -112,8 +110,8 @@ function linear_copy_kernel!(ctx::AbstractKernelContext, dest, dstart, src, ssta
112
110
return
113
111
end
114
112
115
- function Base. copyto! (dest:: AbstractOrWrappedGPUArray , dstart:: Integer ,
116
- src:: AbstractOrWrappedGPUArray , sstart:: Integer , n:: Integer )
113
+ function Base. copyto! (dest:: AnyGPUArray , dstart:: Integer ,
114
+ src:: AnyGPUArray , sstart:: Integer , n:: Integer )
117
115
n == 0 && return dest
118
116
n < 0 && throw (ArgumentError (string (" tried to copy n=" , n, " elements, but n should be nonnegative" )))
119
117
destinds, srcinds = LinearIndices (dest), LinearIndices (src)
152
150
# to quickly perform these very lightweight conversions
153
151
154
152
function Base. copyto! (dest:: Array{T} , dstart:: Integer ,
155
- src:: AbstractOrWrappedGPUArray {U} , sstart:: Integer ,
153
+ src:: AnyGPUArray {U} , sstart:: Integer ,
156
154
n:: Integer ) where {T,U}
157
155
n == 0 && return dest
158
156
temp = Vector {U} (undef, n)
@@ -161,7 +159,7 @@ function Base.copyto!(dest::Array{T}, dstart::Integer,
161
159
return dest
162
160
end
163
161
164
- function Base. copyto! (dest:: AbstractOrWrappedGPUArray {T} , dstart:: Integer ,
162
+ function Base. copyto! (dest:: AnyGPUArray {T} , dstart:: Integer ,
165
163
src:: Array{U} , sstart:: Integer , n:: Integer ) where {T,U}
166
164
n == 0 && return dest
167
165
temp = Vector {T} (undef, n)
@@ -181,8 +179,8 @@ function cartesian_copy_kernel!(ctx::AbstractKernelContext, dest, dest_offsets,
181
179
return
182
180
end
183
181
184
- function Base. copyto! (dest:: AbstractOrWrappedGPUArray {<:Any, N} , destcrange:: CartesianIndices{N} ,
185
- src:: AbstractOrWrappedGPUArray {<:Any, N} , srccrange:: CartesianIndices{N} ) where {N}
182
+ function Base. copyto! (dest:: AnyGPUArray {<:Any, N} , destcrange:: CartesianIndices{N} ,
183
+ src:: AnyGPUArray {<:Any, N} , srccrange:: CartesianIndices{N} ) where {N}
186
184
shape = size (destcrange)
187
185
if shape != size (srccrange)
188
186
throw (ArgumentError (" Ranges don't match their size. Found: $shape , $(size (srccrange)) " ))
0 commit comments