@@ -2,12 +2,36 @@ using Base.Broadcast
2
2
3
3
import Base. Broadcast: BroadcastStyle, Broadcasted, ArrayStyle
4
4
5
- BroadcastStyle (:: Type{T} ) where T <: GPUArray = ArrayStyle {T} ()
5
+ # We define a generic `BroadcastStyle` here that should be sufficient for most cases
6
+ # dependent packages like `CuArrays` can define their own `BroadcastStyle` allowing
7
+ # them to further change or optimize broadcasting.
8
+ # TODO: Investigate whether we should define our own `GPUArrayStyle{N} <: AbstractArrayStyle{N}`
9
+ # NOTE: This uses the specific `T` that was used e.g. `JLArray` or `CLArray` for ArrayStyle, instead
10
+ # of using `ArrayStyle{GPUArray}`; this is because of how `similar` works.
11
+ BroadcastStyle(::Type{T}) where {T<:GPUArray} = ArrayStyle{T}()
6
12
13
+ # These wrapper types otherwise forget that they are GPU compatible
14
+ # Note: Don't directly use ArrayStyle{GPUArray} here since that would mean that `CuArrays`
15
+ # customizations no longer take effect.
16
+ BroadcastStyle(::Type{<:LinearAlgebra.Transpose{<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
17
+ BroadcastStyle(::Type{<:LinearAlgebra.Adjoint{<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
18
+
19
+ # This method is responsible for selecting the output type of broadcast
7
20
# Allocate the destination for a broadcast whose style is `ArrayStyle{GPU}`.
# The array type `GPU` carried by the style parameter is passed to `similar`
# together with the requested element type `ElType` and the axes of `bc`.
function Base.similar(
    bc::Broadcasted{<:ArrayStyle{GPU}}, ::Type{ElType}
) where {GPU<:GPUArray,ElType}
    return similar(GPU, ElType, axes(bc))
end
10
23
24
+ # We purposefully only specialise `copyto!`, dependent packages need to make sure that they can handle:
25
+ # - `bc::Broadcast.Broadcasted{Style}`
26
+ # - `ex::Broadcast.Extruded`
27
+ # - `LinearAlgebra.Transpose{<:Any,<:GPUArray}` and `LinearAlgebra.Adjoint{<:Any,<:GPUArray}`
28
+ # as arguments to a kernel and that they do the right conversion.
29
+ #
30
+ # This Broadcast can be further customised by:
31
+ # - `Broadcast.preprocess(dest::GPUArray, bc::Broadcasted{Nothing})` which allows for a complete transformation
32
+ #   of the `Broadcasted` object based on the output type, right at the end of the pipeline.
33
+ # - `Broadcast.broadcasted(::Style, f)`, which selects an implementation of `f` compatible with `Style`.
34
+ # For more information see the Base documentation.
11
35
@inline function Base. copyto! (dest:: GPUArray , bc:: Broadcasted{Nothing} )
12
36
axes (dest) == axes (bc) || Broadcast. throwdm (axes (dest), axes (bc))
13
37
bc′ = Broadcast. preprocess (dest, bc)
20
44
return dest
21
45
end
22
46
47
+ # TODO : is this still necessary?
23
48
function mapidx (f, A:: GPUArray , args:: NTuple{N, Any} ) where N
24
49
gpu_call (A, (f, A, args)) do state, f, A, args
25
50
ilin = @linearidx (A, state)
0 commit comments