Skip to content

Commit 922e538

Browse files
authored
Merge pull request #323 from JuliaGPU/tb/any_alias
Mapreduce on array wrappers
2 parents 1f38557 + ad2a65b commit 922e538

File tree

7 files changed

+36
-38
lines changed

7 files changed

+36
-38
lines changed

src/host/abstractarray.jl

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@ backend(::Type{<:AbstractGPUDevice}) = error("Not implemented") # COV_EXCL_LINE
2323

2424
const WrappedGPUArray{T,N} = WrappedArray{T,N,AbstractGPUArray,AbstractGPUArray{T,N}}
2525

26-
const AbstractOrWrappedGPUArray{T,N} =
27-
Union{AbstractGPUArray{T,N},
28-
WrappedArray{T,N,AbstractGPUArray,AbstractGPUArray{T,N}}}
26+
const AnyGPUArray{T,N} = Union{AbstractGPUArray{T,N}, WrappedGPUArray{T,N}}
2927

3028

3129
# input/output
@@ -51,21 +49,21 @@ convert_to_cpu(xs) = adapt(Array, xs)
5149
## showing
5250

5351
# display
54-
Base.print_array(io::IO, X::AbstractOrWrappedGPUArray) =
52+
Base.print_array(io::IO, X::AnyGPUArray) =
5553
Base.print_array(io, convert_to_cpu(X))
5654

5755
# show
58-
Base._show_nonempty(io::IO, X::AbstractOrWrappedGPUArray, prefix::String) =
56+
Base._show_nonempty(io::IO, X::AnyGPUArray, prefix::String) =
5957
Base._show_nonempty(io, convert_to_cpu(X), prefix)
60-
Base._show_empty(io::IO, X::AbstractOrWrappedGPUArray) =
58+
Base._show_empty(io::IO, X::AnyGPUArray) =
6159
Base._show_empty(io, convert_to_cpu(X))
62-
Base.show_vector(io::IO, v::AbstractOrWrappedGPUArray, args...) =
60+
Base.show_vector(io::IO, v::AnyGPUArray, args...) =
6361
Base.show_vector(io, convert_to_cpu(v), args...)
6462

6563
## collect to CPU (discarding wrapper type)
6664

6765
collect_to_cpu(xs::AbstractArray) = collect(convert_to_cpu(xs))
68-
Base.collect(X::AbstractOrWrappedGPUArray) = collect_to_cpu(X)
66+
Base.collect(X::AnyGPUArray) = collect_to_cpu(X)
6967

7068

7169
# memory copying
@@ -75,9 +73,9 @@ Base.collect(X::AbstractOrWrappedGPUArray) = collect_to_cpu(X)
7573
# expects the GPU array type to have linear `copyto!` methods (i.e. accepting an integer
7674
# offset and length) from and to CPU arrays and between GPU arrays.
7775

78-
for (D, S) in ((AbstractOrWrappedGPUArray, Array),
79-
(Array, AbstractOrWrappedGPUArray),
80-
(AbstractOrWrappedGPUArray, AbstractOrWrappedGPUArray))
76+
for (D, S) in ((AnyGPUArray, Array),
77+
(Array, AnyGPUArray),
78+
(AnyGPUArray, AnyGPUArray))
8179
@eval begin
8280
function Base.copyto!(dest::$D{<:Any, N}, rdest::UnitRange,
8381
src::$S{<:Any, N}, ssrc::UnitRange) where {N}
@@ -112,8 +110,8 @@ function linear_copy_kernel!(ctx::AbstractKernelContext, dest, dstart, src, ssta
112110
return
113111
end
114112

115-
function Base.copyto!(dest::AbstractOrWrappedGPUArray, dstart::Integer,
116-
src::AbstractOrWrappedGPUArray, sstart::Integer, n::Integer)
113+
function Base.copyto!(dest::AnyGPUArray, dstart::Integer,
114+
src::AnyGPUArray, sstart::Integer, n::Integer)
117115
n == 0 && return dest
118116
n < 0 && throw(ArgumentError(string("tried to copy n=", n, " elements, but n should be nonnegative")))
119117
destinds, srcinds = LinearIndices(dest), LinearIndices(src)
@@ -152,7 +150,7 @@ end
152150
# to quickly perform these very lightweight conversions
153151

154152
function Base.copyto!(dest::Array{T}, dstart::Integer,
155-
src::AbstractOrWrappedGPUArray{U}, sstart::Integer,
153+
src::AnyGPUArray{U}, sstart::Integer,
156154
n::Integer) where {T,U}
157155
n == 0 && return dest
158156
temp = Vector{U}(undef, n)
@@ -161,7 +159,7 @@ function Base.copyto!(dest::Array{T}, dstart::Integer,
161159
return dest
162160
end
163161

164-
function Base.copyto!(dest::AbstractOrWrappedGPUArray{T}, dstart::Integer,
162+
function Base.copyto!(dest::AnyGPUArray{T}, dstart::Integer,
165163
src::Array{U}, sstart::Integer, n::Integer) where {T,U}
166164
n == 0 && return dest
167165
temp = Vector{T}(undef, n)
@@ -181,8 +179,8 @@ function cartesian_copy_kernel!(ctx::AbstractKernelContext, dest, dest_offsets,
181179
return
182180
end
183181

184-
function Base.copyto!(dest::AbstractOrWrappedGPUArray{<:Any, N}, destcrange::CartesianIndices{N},
185-
src::AbstractOrWrappedGPUArray{<:Any, N}, srccrange::CartesianIndices{N}) where {N}
182+
function Base.copyto!(dest::AnyGPUArray{<:Any, N}, destcrange::CartesianIndices{N},
183+
src::AnyGPUArray{<:Any, N}, srccrange::CartesianIndices{N}) where {N}
186184
shape = size(destcrange)
187185
if shape != size(srccrange)
188186
throw(ArgumentError("Ranges don't match their size. Found: $shape, $(size(srccrange))"))

src/host/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Base.Broadcast
66

77
import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle
88

9-
const BroadcastGPUArray{T} = Union{AbstractOrWrappedGPUArray{T},
9+
const BroadcastGPUArray{T} = Union{AnyGPUArray{T},
1010
Base.RefValue{<:AbstractGPUArray{T}}}
1111

1212
"""

src/host/construction.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# constructors and conversions
22

3-
function Base.fill!(A::AbstractOrWrappedGPUArray{T}, x) where T
3+
function Base.fill!(A::AnyGPUArray{T}, x) where T
44
length(A) == 0 && return A
55
gpu_call(A, convert(T, x)) do ctx, a, val
66
idx = @linearidx(a)
@@ -18,16 +18,16 @@ function uniformscaling_kernel(ctx::AbstractKernelContext, res::AbstractArray{T}
1818
return
1919
end
2020

21-
function (T::Type{<: AbstractOrWrappedGPUArray{U}})(s::UniformScaling, dims::Dims{2}) where {U}
21+
function (T::Type{<: AnyGPUArray{U}})(s::UniformScaling, dims::Dims{2}) where {U}
2222
res = similar(T, dims)
2323
fill!(res, zero(U))
2424
gpu_call(uniformscaling_kernel, res, size(res, 1), s; total_threads=minimum(dims))
2525
res
2626
end
2727

28-
(T::Type{<: AbstractOrWrappedGPUArray})(s::UniformScaling{U}, dims::Dims{2}) where U = T{U}(s, dims)
28+
(T::Type{<: AnyGPUArray})(s::UniformScaling{U}, dims::Dims{2}) where U = T{U}(s, dims)
2929

30-
(T::Type{<: AbstractOrWrappedGPUArray})(s::UniformScaling, m::Integer, n::Integer) = T(s, Dims((m, n)))
30+
(T::Type{<: AnyGPUArray})(s::UniformScaling, m::Integer, n::Integer) = T(s, Dims((m, n)))
3131

3232
function Base.copyto!(A::AbstractGPUMatrix{T}, s::UniformScaling) where T
3333
fill!(A, zero(T))

src/host/linalg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ end
9999

100100
## matrix multiplication
101101

102-
function generic_matmatmul!(C::AbstractOrWrappedGPUArray{R}, A::AbstractOrWrappedGPUArray{T}, B::AbstractOrWrappedGPUArray{S}, a::Number, b::Number) where {T,S,R}
102+
function generic_matmatmul!(C::AnyGPUArray{R}, A::AnyGPUArray{T}, B::AnyGPUArray{S}, a::Number, b::Number) where {T,S,R}
103103
if size(A,2) != size(B,1)
104104
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
105105
end

src/host/mapreduce.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ const AbstractArrayOrBroadcasted = Union{AbstractArray,Broadcast.Broadcasted}
44

55
# GPUArrays' mapreduce methods build on `Base.mapreducedim!`, but with an additional
66
# argument `init` value to avoid eager initialization of `R` (if set to something).
7-
mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArrayOrBroadcasted;
7+
mapreducedim!(f, op, R::AnyGPUArray, A::AbstractArrayOrBroadcasted;
88
init=nothing) = error("Not implemented") # COV_EXCL_LINE
99
# resolve ambiguities
10-
Base.mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArray) = mapreducedim!(f, op, R, A)
11-
Base.mapreducedim!(f, op, R::AbstractGPUArray, A::Broadcast.Broadcasted) = mapreducedim!(f, op, R, A)
10+
Base.mapreducedim!(f, op, R::AnyGPUArray, A::AbstractArray) = mapreducedim!(f, op, R, A)
11+
Base.mapreducedim!(f, op, R::AnyGPUArray, A::Broadcast.Broadcasted) = mapreducedim!(f, op, R, A)
1212

1313
neutral_element(op, T) =
1414
error("""GPUArrays.jl needs to know the neutral element for your operator `$op`.
@@ -24,7 +24,7 @@ neutral_element(::typeof(Base.min), T) = typemax(T)
2424
neutral_element(::typeof(Base.max), T) = typemin(T)
2525

2626
# resolve ambiguities
27-
Base.mapreduce(f, op, A::AbstractGPUArray, As::AbstractArrayOrBroadcasted...;
27+
Base.mapreduce(f, op, A::AnyGPUArray, As::AbstractArrayOrBroadcasted...;
2828
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)
2929
Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:AbstractGPUArrayStyle}, As::AbstractArrayOrBroadcasted...;
3030
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)
@@ -68,24 +68,24 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
6868
end
6969
end
7070

71-
Base.any(A::AbstractGPUArray{Bool}) = mapreduce(identity, |, A)
72-
Base.all(A::AbstractGPUArray{Bool}) = mapreduce(identity, &, A)
71+
Base.any(A::AnyGPUArray{Bool}) = mapreduce(identity, |, A)
72+
Base.all(A::AnyGPUArray{Bool}) = mapreduce(identity, &, A)
7373

74-
Base.any(f::Function, A::AbstractGPUArray) = mapreduce(f, |, A)
75-
Base.all(f::Function, A::AbstractGPUArray) = mapreduce(f, &, A)
74+
Base.any(f::Function, A::AnyGPUArray) = mapreduce(f, |, A)
75+
Base.all(f::Function, A::AnyGPUArray) = mapreduce(f, &, A)
7676

77-
Base.count(pred::Function, A::AbstractGPUArray; dims=:) =
77+
Base.count(pred::Function, A::AnyGPUArray; dims=:) =
7878
mapreduce(pred, Base.add_sum, A; init=0, dims=dims)
7979

80-
Base.:(==)(A::AbstractGPUArray, B::AbstractGPUArray) = Bool(mapreduce(==, &, A, B))
80+
Base.:(==)(A::AnyGPUArray, B::AnyGPUArray) = Bool(mapreduce(==, &, A, B))
8181

8282
# avoid calling into `initarray!`
8383
for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
8484
(:maximum, :(Base.max)), (:minimum, :(Base.min)),
8585
(:all, :&), (:any, :|)]
8686
fname! = Symbol(fname, '!')
8787
@eval begin
88-
Base.$(fname!)(f::Function, r::AbstractGPUArray, A::AbstractGPUArray{T}) where T =
88+
Base.$(fname!)(f::Function, r::AnyGPUArray, A::AnyGPUArray{T}) where T =
8989
GPUArrays.mapreducedim!(f, $(op), r, A; init=neutral_element($(op), T))
9090
end
9191
end

src/host/math.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Base mathematical operations
22

3-
function Base.clamp!(A::AbstractOrWrappedGPUArray, low, high)
3+
function Base.clamp!(A::AnyGPUArray, low, high)
44
gpu_call(A, low, high) do ctx, A, low, high
55
I = @cartesianidx A
66
A[I] = clamp(A[I], low, high)

src/host/random.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ struct RNG <: AbstractRNG
6767
end
6868

6969
# return an instance of GPUArrays.RNG suitable for the requested array type
70-
default_rng(::Type{<:AbstractGPUArray}) = error("Not implemented") # COV_EXCL_LINE
70+
default_rng(::Type{<:AnyGPUArray}) = error("Not implemented") # COV_EXCL_LINE
7171

7272
make_seed(rng::RNG) = make_seed(rng, rand(UInt))
7373
function make_seed(rng::RNG, n::Integer)
@@ -81,7 +81,7 @@ function Random.seed!(rng::RNG, seed::Vector{UInt32})
8181
return
8282
end
8383

84-
function Random.rand!(rng::RNG, A::AbstractGPUArray{T}) where T <: Number
84+
function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
8585
gpu_call(A, rng.state) do ctx, a, randstates
8686
idx = linear_index(ctx)
8787
idx > length(a) && return
@@ -91,7 +91,7 @@ function Random.rand!(rng::RNG, A::AbstractGPUArray{T}) where T <: Number
9191
A
9292
end
9393

94-
function Random.randn!(rng::RNG, A::AbstractGPUArray{T}) where T <: Number
94+
function Random.randn!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
9595
threads = (length(A) - 1) ÷ 2 + 1
9696
length(A) == 0 && return
9797
gpu_call(A, rng.state; total_threads = threads) do ctx, a, randstates

0 commit comments

Comments
 (0)