@@ -2,13 +2,56 @@ using Base.Broadcast
2
2
3
3
import Base. Broadcast: BroadcastStyle, Broadcasted, ArrayStyle
4
4
5
- BroadcastStyle (:: Type{T} ) where T <: GPUArray = ArrayStyle {T} ()
5
+ # we define a generic `BroadcastStyle` here that should be sufficient for most cases.
6
+ # dependent packages like `CuArrays` can define their own `BroadcastStyle` allowing
7
+ # them to further change or optimize broadcasting.
8
+ #
9
+ # TODO: investigate if we should define our own `GPUArrayStyle{N} <: AbstractArrayStyle{N}`
10
+ #
11
+ # NOTE: this uses the specific `T` that was used (e.g. `JLArray` or `CLArray`) for `ArrayStyle`,
12
+ # instead of using `ArrayStyle{GPUArray}`, because of how `similar` works.
13
+ BroadcastStyle (:: Type{T} ) where {T<: GPUArray } = ArrayStyle {T} ()
6
14
7
- function Base. similar (bc:: Broadcasted{<:ArrayStyle{GPU}} , :: Type{ElType} ) where {GPU <: GPUArray , ElType}
15
+ # These wrapper types otherwise forget that they are GPU compatible
16
+ #
17
+ # NOTE: Don't directly use ArrayStyle{GPUArray} here since that would mean that `CuArrays`
18
+ # customizations no longer take effect.
19
+ BroadcastStyle (:: Type{<:LinearAlgebra.Transpose{<:Any,T}} ) where {T<: GPUArray } = BroadcastStyle (T)
20
+ BroadcastStyle (:: Type{<:LinearAlgebra.Adjoint{<:Any,T}} ) where {T<: GPUArray } = BroadcastStyle (T)
21
+ BroadcastStyle (:: Type{<:SubArray{<:Any,<:Any,T}} ) where {T<: GPUArray } = BroadcastStyle (T)
22
+
23
+ backend (:: Type{<:LinearAlgebra.Transpose{<:Any,T}} ) where {T<: GPUArray } = backend (T)
24
+ backend (:: Type{<:LinearAlgebra.Adjoint{<:Any,T}} ) where {T<: GPUArray } = backend (T)
25
+ backend (:: Type{<:SubArray{<:Any,<:Any,T}} ) where {T<: GPUArray } = backend (T)
26
+
27
+ # This Union is a hack. Ideally Base would have a Transpose <: WrappedArray <: AbstractArray
28
+ # and we could define our methods in terms of Union{GPUArray, WrappedArray{<:Any, <:GPUArray}}
29
+ const GPUDestArray = Union{GPUArray,
30
+ LinearAlgebra. Transpose{<: Any ,<: GPUArray },
31
+ LinearAlgebra. Adjoint{<: Any ,<: GPUArray },
32
+ SubArray{<: Any ,<: Any ,<: GPUArray }}
33
+
34
+ # This method is responsible for selecting the output type of broadcast
35
+ function Base. similar (bc:: Broadcasted{<:ArrayStyle{GPU}} , :: Type{ElType} ) where
36
+ {GPU <: GPUArray , ElType}
8
37
similar (GPU, ElType, axes (bc))
9
38
end
10
39
11
- @inline function Base. copyto! (dest:: GPUArray , bc:: Broadcasted{Nothing} )
40
+ # We purposefully only specialize `copyto!`; dependent packages need to make sure that they
41
+ # can handle:
42
+ # - `bc::Broadcast.Broadcasted{Style}`
43
+ # - `ex::Broadcast.Extruded`
44
+ # - `LinearAlgebra.Transpose{<:Any,<:GPUArray}` and `LinearAlgebra.Adjoint{<:Any,<:GPUArray}`, etc.
45
+ # as arguments to a kernel and that they do the right conversion.
46
+ #
47
+ # This Broadcast can be further customized by:
48
+ # - `Broadcast.preprocess(dest::GPUArray, bc::Broadcasted{Nothing})` which allows for a
49
+ # complete transformation based on the output type just at the end of the pipeline.
50
+ # - `Broadcast.broadcasted(::Style, f)` which selects an implementation of `f` compatible
51
+ # with `Style`
52
+ #
53
+ # For more information see the Base documentation.
54
+ @inline function Base. copyto! (dest:: GPUDestArray , bc:: Broadcasted{Nothing} )
12
55
axes (dest) == axes (bc) || Broadcast. throwdm (axes (dest), axes (bc))
13
56
bc′ = Broadcast. preprocess (dest, bc)
14
57
gpu_call (dest, (dest, bc′)) do state, dest, bc′
20
63
return dest
21
64
end
22
65
66
+ # Base defines this method as a performance optimization, but we don't know how to do
67
+ # `fill!` in general for all `GPUDestArray` so we just go straight to the fallback
68
+ @inline Base. copyto! (dest:: GPUDestArray , bc:: Broadcasted{<:Broadcast.AbstractArrayStyle{0}} ) =
69
+ copyto! (dest, convert (Broadcasted{Nothing}, bc))
70
+
71
+ # TODO : is this still necessary?
23
72
function mapidx (f, A:: GPUArray , args:: NTuple{N, Any} ) where N
24
73
gpu_call (A, (f, A, args)) do state, f, A, args
25
74
ilin = @linearidx (A, state)
0 commit comments