Skip to content

Commit c6ca41f

Browse files
vchuravymaleadt
authored andcommitted
handle broadcast of Transpose and Adjoint
1 parent 712790f commit c6ca41f

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

src/broadcast.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,36 @@ using Base.Broadcast
22

33
import Base.Broadcast: BroadcastStyle, Broadcasted, ArrayStyle
44

5-
BroadcastStyle(::Type{T}) where T <: GPUArray = ArrayStyle{T}()
5+
# We define a generic `BroadcastStyle` here that should be sufficient for most cases
6+
# dependent packages like `CuArrays` can define their own `BroadcastStyle` allowing
7+
# them to further change or optimize broadcasting.
8+
# TODO: Investigate if we should define out own `GPUArrayStyle{N} <: AbstractArrayStyle{N}`
9+
# NOTE: This uses the specific `T` that was used e.g. `JLArray` or `CLArray` for ArrayStyle, instead
10+
# of using `ArrayStyle{GPUArray}`, this is due to the fact how `similar` works.
11+
BroadcastStyle(::Type{T}) where {T<:GPUArray} = ArrayStyle{T}()
612

13+
# These wrapper types otherwise forget that they are GPU compatible
14+
# Note: Don't directly use ArrayStyle{GPUArray} here since that would mean that `CuArrays`
15+
# customizations no longer take effect.
16+
BroadcastStyle(::Type{<:LinearAlgebra.Transpose{<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
17+
BroadcastStyle(::Type{<:LinearAlgebra.Adjoint{<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
18+
19+
# This method is responsible for selection the output type of broadcast
720
function Base.similar(bc::Broadcasted{<:ArrayStyle{GPU}}, ::Type{ElType}) where {GPU <: GPUArray, ElType}
821
similar(GPU, ElType, axes(bc))
922
end
1023

24+
# We purposefully only specialise `copyto!`, dependent packages need to make sure that they can handle:
25+
# - `bc::Broadcast.Broadcasted{Style}`
26+
# - `ex::Broadcast.Extruded`
27+
# - `LinearAlgebra.Transpose{,<:GPUArray}` and `LinearAlgebra.Adjoint{,<:GPUArray}`
28+
# as arguments to a kernel and that they do the right conversion.
29+
#
30+
# This Broadcast can be further customised by:
31+
# - `Broadcast.preprocess(dest::GPUArray, bc::Broadcasted{Nothing})` which allows for a complete transformation
32+
# Broadcasted based on the output type just at the end of the pipeline.
33+
# - `Broadcast.broadcasted(::Style, f)` selection of an implementation of `f` compatible with `Style`
34+
# For more information see the Base documentation.
1135
@inline function Base.copyto!(dest::GPUArray, bc::Broadcasted{Nothing})
1236
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
1337
bc′ = Broadcast.preprocess(dest, bc)
@@ -20,6 +44,7 @@ end
2044
return dest
2145
end
2246

47+
# TODO: is this still necessary?
2348
function mapidx(f, A::GPUArray, args::NTuple{N, Any}) where N
2449
gpu_call(A, (f, A, args)) do state, f, A, args
2550
ilin = @linearidx(A, state)

0 commit comments

Comments
 (0)