@@ -2,44 +2,54 @@ using Base.Broadcast
2
2
3
3
import Base. Broadcast: BroadcastStyle, Broadcasted, ArrayStyle
4
4
5
- # We define a generic `BroadcastStyle` here that should be sufficient for most cases
5
+ # we define a generic `BroadcastStyle` here that should be sufficient for most cases.
6
6
# dependent packages like `CuArrays` can define their own `BroadcastStyle` allowing
7
7
# them to further change or optimize broadcasting.
8
- # TODO : Investigate if we should define out own `GPUArrayStyle{N} <: AbstractArrayStyle{N}`
9
- # NOTE: This uses the specific `T` that was used e.g. `JLArray` or `CLArray` for ArrayStyle, instead
10
- # of using `ArrayStyle{GPUArray}`, this is due to the fact how `similar` works.
8
+ #
9
+ # TODO : investigate if we should define out own `GPUArrayStyle{N} <: AbstractArrayStyle{N}`
10
+ #
11
+ # NOTE: this uses the specific `T` that was used e.g. `JLArray` or `CLArray` for ArrayStyle,
12
+ # instead of using `ArrayStyle{GPUArray}`, due to the fact how `similar` works.
11
13
BroadcastStyle (:: Type{T} ) where {T<: GPUArray } = ArrayStyle {T} ()
12
14
13
15
# These wrapper types otherwise forget that they are GPU compatible
14
- # Note: Don't directly use ArrayStyle{GPUArray} here since that would mean that `CuArrays`
15
- # customizations no longer take effect.
16
+ #
17
+ # NOTE: Don't directly use ArrayStyle{GPUArray} here since that would mean that `CuArrays`
18
+ # customization no longer take effect.
16
19
BroadcastStyle (:: Type{<:LinearAlgebra.Transpose{<:Any,T}} ) where {T<: GPUArray } = BroadcastStyle (T)
17
20
BroadcastStyle (:: Type{<:LinearAlgebra.Adjoint{<:Any,T}} ) where {T<: GPUArray } = BroadcastStyle (T)
21
+ BroadcastStyle (:: Type{<:SubArray{<:Any,<:Any,T}} ) where {T<: GPUArray } = BroadcastStyle (T)
18
22
19
23
backend (:: Type{<:LinearAlgebra.Transpose{<:Any,T}} ) where {T<: GPUArray } = backend (T)
20
24
backend (:: Type{<:LinearAlgebra.Adjoint{<:Any,T}} ) where {T<: GPUArray } = backend (T)
25
+ backend (:: Type{<:SubArray{<:Any,<:Any,T}} ) where {T<: GPUArray } = backend (T)
21
26
22
27
# This Union is a hack. Ideally Base would have a Transpose <: WrappedArray <: AbstractArray
23
28
# and we could define our methods in terms of Union{GPUArray, WrappedArray{<:Any, <:GPUArray}}
24
29
const GPUDestArray = Union{GPUArray,
25
30
LinearAlgebra. Transpose{<: Any ,<: GPUArray },
26
- LinearAlgebra. Adjoint{<: Any ,<: GPUArray }}
31
+ LinearAlgebra. Adjoint{<: Any ,<: GPUArray },
32
+ SubArray{<: Any ,<: Any ,<: GPUArray }}
27
33
28
34
# This method is responsible for selection the output type of broadcast
29
- function Base. similar (bc:: Broadcasted{<:ArrayStyle{GPU}} , :: Type{ElType} ) where {GPU <: GPUArray , ElType}
35
+ function Base. similar (bc:: Broadcasted{<:ArrayStyle{GPU}} , :: Type{ElType} ) where
36
+ {GPU <: GPUArray , ElType}
30
37
similar (GPU, ElType, axes (bc))
31
38
end
32
39
33
- # We purposefully only specialise `copyto!`, dependent packages need to make sure that they can handle:
40
+ # We purposefully only specialize `copyto!`, dependent packages need to make sure that they
41
+ # can handle:
34
42
# - `bc::Broadcast.Broadcasted{Style}`
35
43
# - `ex::Broadcast.Extruded`
36
- # - `LinearAlgebra.Transpose{,<:GPUArray}` and `LinearAlgebra.Adjoint{,<:GPUArray}`
37
- # as arguments to a kernel and that they do the right conversion.
44
+ # - `LinearAlgebra.Transpose{,<:GPUArray}` and `LinearAlgebra.Adjoint{,<:GPUArray}`, etc
45
+ # as arguments to a kernel and that they do the right conversion.
46
+ #
47
+ # This Broadcast can be further customize by:
48
+ # - `Broadcast.preprocess(dest::GPUArray, bc::Broadcasted{Nothing})` which allows for a
49
+ # complete transformation based on the output type just at the end of the pipeline.
50
+ # - `Broadcast.broadcasted(::Style, f)` selection of an implementation of `f` compatible
51
+ # with `Style`
38
52
#
39
- # This Broadcast can be further customised by:
40
- # - `Broadcast.preprocess(dest::GPUArray, bc::Broadcasted{Nothing})` which allows for a complete transformation
41
- # Broadcasted based on the output type just at the end of the pipeline.
42
- # - `Broadcast.broadcasted(::Style, f)` selection of an implementation of `f` compatible with `Style`
43
53
# For more information see the Base documentation.
44
54
@inline function Base. copyto! (dest:: GPUDestArray , bc:: Broadcasted{Nothing} )
45
55
axes (dest) == axes (bc) || Broadcast. throwdm (axes (dest), axes (bc))
53
63
return dest
54
64
end
55
65
56
- # Base defines this method as a performance optimization, but we don't know how
57
- # to do `fill!` in general for all `GPUDestArray` so we just straight go to the fallback
58
- @inline function Base. copyto! (dest:: GPUDestArray , bc:: Broadcasted{<:Broadcast.AbstractArrayStyle{0}} )
59
- return copyto! (dest, convert (Broadcasted{Nothing}, bc))
60
- end
66
+ # Base defines this method as a performance optimization, but we don't know how to do
67
+ # `fill!` in general for all `GPUDestArray` so we just go straight to the fallback
68
+ @inline Base. copyto! (dest:: GPUDestArray , bc:: Broadcasted{<:Broadcast.AbstractArrayStyle{0}} ) =
69
+ copyto! (dest, convert (Broadcasted{Nothing}, bc))
61
70
62
71
# TODO : is this still necessary?
63
72
function mapidx (f, A:: GPUArray , args:: NTuple{N, Any} ) where N
0 commit comments