
Commit 6cec87e

Merge pull request #148 from JuliaGPU/vc/trbc
Support `Transpose` and `Adjoint` in broadcast better
2 parents: 712790f + 5132df8

File tree: 6 files changed (+86 / -26 lines)

src/abstract_gpu_interface.jl (5 additions, 5 deletions)

@@ -106,8 +106,8 @@ end
 # CUDAnative.__syncthreads()
 # end

-
-
+abstract type GPUBackend end
+backend(::Type{T}) where T = error("Can't choose GPU backend for $T")

 """
     gpu_call(kernel::Function, A::GPUArray, args::Tuple, configuration = length(A))
@@ -124,7 +124,7 @@ Optionally, a launch configuration can be supplied in the following way:
 2) Pass a tuple of integer tuples to define blocks and threads per blocks!

 """
-function gpu_call(kernel, A::GPUArray, args::Tuple, configuration = length(A))
+function gpu_call(kernel, A::AbstractArray, args::Tuple, configuration = length(A))
     ITuple = NTuple{N, Integer} where N
     # If is a single integer, we assume it to be the global size / total number of threads one wants to launch
     thread_blocks = if isa(configuration, Integer)
@@ -148,8 +148,8 @@ function gpu_call(kernel, A::GPUArray, args::Tuple, configuration = length(A))
         `linear_index` will be inbetween 1:prod((blocks..., threads...))
         """)
     end
-    _gpu_call(kernel, A, args, thread_blocks)
+    _gpu_call(backend(typeof(A)), kernel, A, args, thread_blocks)
 end

 # Internal GPU call function, that needs to be overloaded by the backends.
-_gpu_call(f, A, args, thread_blocks) = error("Not implemented")
+_gpu_call(::Any, f, A, args, thread_blocks) = error("Not implemented")
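
The takeaway from this file: kernel launches now go through a two-step dispatch. `gpu_call` asks `backend(typeof(A))` which `GPUBackend` owns the array, then dispatches `_gpu_call` on that backend value. A minimal sketch of how a downstream package could hook in; `MyBackend` and `MyArray` are hypothetical names, not part of this commit:

    # Hypothetical backend plugging into the interface introduced above.
    struct MyBackend <: GPUBackend end
    struct MyArray{T, N} <: GPUArray{T, N} end   # fields omitted for the sketch

    # Tell GPUArrays which backend owns this array type.
    backend(::Type{<:MyArray}) = MyBackend()

    # The single launch entry point a backend has to provide:
    # gpu_call(kernel, A, args, config) ends in
    #     _gpu_call(backend(typeof(A)), kernel, A, args, thread_blocks)
    function _gpu_call(::MyBackend, f, A, args, thread_blocks)
        blocks, threads = thread_blocks
        # enqueue `f(state, args...)` on the device using `blocks`/`threads` here
    end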

src/abstractarray.jl (14 additions, 16 deletions)

@@ -30,10 +30,6 @@ function deserialize(s::AbstractSerializer, ::Type{T}) where T <: GPUArray
     T(A)
 end

-@inline unpack_buffer(x) = x
-@inline unpack_buffer(x::GPUArray) = pointer(x)
-@inline unpack_buffer(x::Ref{<: GPUArray}) = unpack_buffer(x[])
-
 function to_cartesian(A, indices::Tuple)
     start = CartesianIndex(ntuple(length(indices)) do i
         val = indices[i]
@@ -56,22 +52,24 @@ end

 ## showing

-for (atype, op) in
-    [(:(GPUArray), :(Array)),
-     (:(LinearAlgebra.Adjoint{<:Any,<:GPUArray}), :(x->LinearAlgebra.adjoint(Array(parent(x))))),
-     (:(LinearAlgebra.Transpose{<:Any,<:GPUArray}), :(x->LinearAlgebra.transpose(Array(parent(x)))))]
+for (AT, f) in
+    (GPUArray => Array,
+     LinearAlgebra.Adjoint{<:Any,<:GPUArray} => x->LinearAlgebra.adjoint(Array(parent(x))),
+     LinearAlgebra.Transpose{<:Any,<:GPUArray} => x->LinearAlgebra.transpose(Array(parent(x))),
+     SubArray{<:Any,<:Any,<:GPUArray} => x->SubArray(Array(parent(x)), parentindices(x))
+    )
     @eval begin
         # for display
-        Base.print_array(io::IO, X::($atype)) =
-            Base.print_array(io,($op)(X))
+        Base.print_array(io::IO, X::$AT) =
+            Base.print_array(io,$f(X))

         # for show
-        Base._show_nonempty(io::IO, X::($atype), prefix::String) =
-            Base._show_nonempty(io,($op)(X),prefix)
-        Base._show_empty(io::IO, X::($atype)) =
-            Base._show_empty(io,($op)(X))
-        Base.show_vector(io::IO, v::($atype), args...) =
-            Base.show_vector(io,($op)(v),args...)
+        Base._show_nonempty(io::IO, X::$AT, prefix::String) =
+            Base._show_nonempty(io,$f(X),prefix)
+        Base._show_empty(io::IO, X::$AT) =
+            Base._show_empty(io,$f(X))
+        Base.show_vector(io::IO, v::$AT, args...) =
+            Base.show_vector(io,$f(v),args...)
     end
 end
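
For reference, each pair in the rewritten loop above just interpolates a wrapper type and a conversion function into forwarding methods. Roughly what one iteration of the `@eval` loop generates, shown here for the `Adjoint` entry:

    # Approximate expansion of the Adjoint entry of the loop above:
    Base.print_array(io::IO, X::LinearAlgebra.Adjoint{<:Any,<:GPUArray}) =
        Base.print_array(io, LinearAlgebra.adjoint(Array(parent(X))))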

src/array.jl (3 additions, 1 deletion)

@@ -21,6 +21,8 @@ function JLArray{T, N}(size::NTuple{N, Integer}) where {T, N}
     JLArray{T, N}(Array{T, N}(undef, size), size)
 end

+struct JLBackend <: GPUBackend end
+backend(::Type{<:JLArray}) = JLBackend()

 ## getters

@@ -120,7 +122,7 @@ function AbstractDeviceArray(ptr::Array, shape::Vararg{Integer, N}) where N
     reshape(ptr, shape)
 end

-function _gpu_call(f, A::JLArray, args::Tuple, blocks_threads::Tuple{T, T}) where T <: NTuple{N, Integer} where N
+function _gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tuple{T, T}) where T <: NTuple{N, Integer} where N
     blocks, threads = blocks_threads
     idx = ntuple(i-> 1, length(blocks))
     blockdim = blocks
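
With these two definitions the JLArray reference backend is wired into the new indirection. A hypothetical trace of a launch (the constructor is the one shown in the first hunk; `backend` and `_gpu_call` are internal, so they are written module-qualified here):

    A = GPUArrays.JLArray{Float32, 1}((8,))

    GPUArrays.backend(typeof(A))   # JLBackend()
    # gpu_call(kernel, A, args) therefore lowers to
    #     _gpu_call(JLBackend(), kernel, A, args, thread_blocks)
    # i.e. the method redefined at the bottom of this file.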

src/broadcast.jl (52 additions, 3 deletions)

@@ -2,13 +2,56 @@ using Base.Broadcast

 import Base.Broadcast: BroadcastStyle, Broadcasted, ArrayStyle

-BroadcastStyle(::Type{T}) where T <: GPUArray = ArrayStyle{T}()
+# we define a generic `BroadcastStyle` here that should be sufficient for most cases.
+# dependent packages like `CuArrays` can define their own `BroadcastStyle` allowing
+# them to further change or optimize broadcasting.
+#
+# TODO: investigate if we should define out own `GPUArrayStyle{N} <: AbstractArrayStyle{N}`
+#
+# NOTE: this uses the specific `T` that was used e.g. `JLArray` or `CLArray` for ArrayStyle,
+# instead of using `ArrayStyle{GPUArray}`, due to the fact how `similar` works.
+BroadcastStyle(::Type{T}) where {T<:GPUArray} = ArrayStyle{T}()

-function Base.similar(bc::Broadcasted{<:ArrayStyle{GPU}}, ::Type{ElType}) where {GPU <: GPUArray, ElType}
+# These wrapper types otherwise forget that they are GPU compatible
+#
+# NOTE: Don't directly use ArrayStyle{GPUArray} here since that would mean that `CuArrays`
+# customization no longer take effect.
+BroadcastStyle(::Type{<:LinearAlgebra.Transpose{<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
+BroadcastStyle(::Type{<:LinearAlgebra.Adjoint{<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
+BroadcastStyle(::Type{<:SubArray{<:Any,<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
+
+backend(::Type{<:LinearAlgebra.Transpose{<:Any,T}}) where {T<:GPUArray} = backend(T)
+backend(::Type{<:LinearAlgebra.Adjoint{<:Any,T}}) where {T<:GPUArray} = backend(T)
+backend(::Type{<:SubArray{<:Any,<:Any,T}}) where {T<:GPUArray} = backend(T)
+
+# This Union is a hack. Ideally Base would have a Transpose <: WrappedArray <: AbstractArray
+# and we could define our methods in terms of Union{GPUArray, WrappedArray{<:Any, <:GPUArray}}
+const GPUDestArray = Union{GPUArray,
+                           LinearAlgebra.Transpose{<:Any,<:GPUArray},
+                           LinearAlgebra.Adjoint{<:Any,<:GPUArray},
+                           SubArray{<:Any,<:Any,<:GPUArray}}
+
+# This method is responsible for selection the output type of broadcast
+function Base.similar(bc::Broadcasted{<:ArrayStyle{GPU}}, ::Type{ElType}) where
+        {GPU <: GPUArray, ElType}
     similar(GPU, ElType, axes(bc))
 end

-@inline function Base.copyto!(dest::GPUArray, bc::Broadcasted{Nothing})
+# We purposefully only specialize `copyto!`, dependent packages need to make sure that they
+# can handle:
+# - `bc::Broadcast.Broadcasted{Style}`
+# - `ex::Broadcast.Extruded`
+# - `LinearAlgebra.Transpose{,<:GPUArray}` and `LinearAlgebra.Adjoint{,<:GPUArray}`, etc
+# as arguments to a kernel and that they do the right conversion.
+#
+# This Broadcast can be further customize by:
+# - `Broadcast.preprocess(dest::GPUArray, bc::Broadcasted{Nothing})` which allows for a
+#   complete transformation based on the output type just at the end of the pipeline.
+# - `Broadcast.broadcasted(::Style, f)` selection of an implementation of `f` compatible
+#   with `Style`
+#
+# For more information see the Base documentation.
+@inline function Base.copyto!(dest::GPUDestArray, bc::Broadcasted{Nothing})
     axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
     bc′ = Broadcast.preprocess(dest, bc)
     gpu_call(dest, (dest, bc′)) do state, dest, bc′
@@ -20,6 +63,12 @@ end
     return dest
 end

+# Base defines this method as a performance optimization, but we don't know how to do
+# `fill!` in general for all `GPUDestArray` so we just go straight to the fallback
+@inline Base.copyto!(dest::GPUDestArray, bc::Broadcasted{<:Broadcast.AbstractArrayStyle{0}}) =
+    copyto!(dest, convert(Broadcasted{Nothing}, bc))
+
+# TODO: is this still necessary?
 function mapidx(f, A::GPUArray, args::NTuple{N, Any}) where N
     gpu_call(A, (f, A, args)) do state, f, A, args
         ilin = @linearidx(A, state)
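
The net effect: broadcasting that involves a wrapped GPU array stays on the GPU path, because the wrapper's `BroadcastStyle` and `backend` forward to the parent array's, and `copyto!` now accepts any `GPUDestArray` destination. A hypothetical session against the JLArray reference backend, not part of the diff:

    using LinearAlgebra
    A = GPUArrays.JLArray{Float32, 2}((4, 4))

    Base.BroadcastStyle(typeof(A'))   # ArrayStyle{JLArray{Float32,2}}, not the CPU default
    A' .= 1f0                         # ends up in copyto!(dest::GPUDestArray, bc::Broadcasted{Nothing})
    view(A, :, 1) .+ 2f0              # SubArray forwards the same way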

src/mapreduce.jl (4 additions, 1 deletion)

@@ -5,7 +5,10 @@

 Base.any(A::GPUArray{Bool}) = mapreduce(identity, |, A; init = false)
 Base.all(A::GPUArray{Bool}) = mapreduce(identity, &, A; init = true)
-Base.count(pred, A::GPUArray) = Int(mapreduce(pred, +, A; init = 0))
+
+Base.any(f::Function, A::GPUArray) = mapreduce(f, |, A; init = false)
+Base.all(f::Function, A::GPUArray) = mapreduce(f, &, A; init = true)
+Base.count(pred::Function, A::GPUArray) = Int(mapreduce(pred, +, A; init = 0))

 Base.:(==)(A::GPUArray, B::GPUArray) = Bool(mapreduce(==, &, A, B; init = true))
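
These predicate forms route `any`, `all`, and `count` through the GPU `mapreduce` instead of Base's generic iteration fallback. Hypothetical usage on the JLArray reference backend:

    A = GPUArrays.JLArray{Float32, 1}((16,))
    A .= 1f0

    any(iszero, A)            # mapreduce(iszero, |, A; init = false)
    all(x -> x > 0f0, A)      # mapreduce(f, &, A; init = true)
    count(x -> x == 1f0, A)   # Int(mapreduce(pred, +, A; init = 0)) == 16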

src/testsuite/broadcasting.jl (8 additions, 0 deletions)

@@ -53,6 +53,14 @@ function broadcasting(AT)
         end
     end

+    @testset "Adjoint and Transpose" begin
+        A = AT(rand(ET, N))
+        A' .= ET(2)
+        @test all(x->x==ET(2), A)
+        transpose(A) .= ET(1)
+        @test all(x->x==ET(1), A)
+    end
+
     ############
     # issue #27
     @test compare((a, b)-> a .+ b, AT, rand(ET, 4, 5, 3), rand(ET, 1, 5, 3))
