Skip to content

Commit 460ea24

Browse files
committed
support for subarray, and clean-ups.
1 parent 1518340 commit 460ea24

File tree

1 file changed

+29
-20
lines changed

1 file changed

+29
-20
lines changed

src/broadcast.jl

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,44 +2,54 @@ using Base.Broadcast
22

33
import Base.Broadcast: BroadcastStyle, Broadcasted, ArrayStyle
44

5-
# We define a generic `BroadcastStyle` here that should be sufficient for most cases
5+
# we define a generic `BroadcastStyle` here that should be sufficient for most cases.
66
# dependent packages like `CuArrays` can define their own `BroadcastStyle` allowing
77
# them to further change or optimize broadcasting.
8-
# TODO: Investigate if we should define out own `GPUArrayStyle{N} <: AbstractArrayStyle{N}`
9-
# NOTE: This uses the specific `T` that was used e.g. `JLArray` or `CLArray` for ArrayStyle, instead
10-
# of using `ArrayStyle{GPUArray}`, this is due to the fact how `similar` works.
8+
#
9+
# TODO: investigate if we should define our own `GPUArrayStyle{N} <: AbstractArrayStyle{N}`
10+
#
11+
# NOTE: this uses the specific `T` that was used e.g. `JLArray` or `CLArray` for ArrayStyle,
12+
# instead of using `ArrayStyle{GPUArray}`, because of how `similar` works.
1113
BroadcastStyle(::Type{T}) where {T<:GPUArray} = ArrayStyle{T}()
1214

1315
# These wrapper types otherwise forget that they are GPU compatible
14-
# Note: Don't directly use ArrayStyle{GPUArray} here since that would mean that `CuArrays`
15-
# customizations no longer take effect.
16+
#
17+
# NOTE: Don't directly use ArrayStyle{GPUArray} here since that would mean that `CuArrays`
18+
# customizations no longer take effect.
1619
BroadcastStyle(::Type{<:LinearAlgebra.Transpose{<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
1720
BroadcastStyle(::Type{<:LinearAlgebra.Adjoint{<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
21+
BroadcastStyle(::Type{<:SubArray{<:Any,<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
1822

1923
backend(::Type{<:LinearAlgebra.Transpose{<:Any,T}}) where {T<:GPUArray} = backend(T)
2024
backend(::Type{<:LinearAlgebra.Adjoint{<:Any,T}}) where {T<:GPUArray} = backend(T)
25+
backend(::Type{<:SubArray{<:Any,<:Any,T}}) where {T<:GPUArray} = backend(T)
2126

2227
# This Union is a hack. Ideally Base would have a Transpose <: WrappedArray <: AbstractArray
2328
# and we could define our methods in terms of Union{GPUArray, WrappedArray{<:Any, <:GPUArray}}
2429
const GPUDestArray = Union{GPUArray,
2530
LinearAlgebra.Transpose{<:Any,<:GPUArray},
26-
LinearAlgebra.Adjoint{<:Any,<:GPUArray}}
31+
LinearAlgebra.Adjoint{<:Any,<:GPUArray},
32+
SubArray{<:Any,<:Any,<:GPUArray}}
2733

2834
# This method is responsible for selecting the output type of broadcast
29-
function Base.similar(bc::Broadcasted{<:ArrayStyle{GPU}}, ::Type{ElType}) where {GPU <: GPUArray, ElType}
35+
function Base.similar(bc::Broadcasted{<:ArrayStyle{GPU}}, ::Type{ElType}) where
36+
{GPU <: GPUArray, ElType}
3037
similar(GPU, ElType, axes(bc))
3138
end
3239

33-
# We purposefully only specialise `copyto!`, dependent packages need to make sure that they can handle:
40+
# We purposefully only specialize `copyto!`, dependent packages need to make sure that they
41+
# can handle:
3442
# - `bc::Broadcast.Broadcasted{Style}`
3543
# - `ex::Broadcast.Extruded`
36-
# - `LinearAlgebra.Transpose{,<:GPUArray}` and `LinearAlgebra.Adjoint{,<:GPUArray}`
37-
# as arguments to a kernel and that they do the right conversion.
44+
# - `LinearAlgebra.Transpose{,<:GPUArray}` and `LinearAlgebra.Adjoint{,<:GPUArray}`, etc
45+
# as arguments to a kernel and that they do the right conversion.
46+
#
47+
# This Broadcast can be further customized by:
48+
# - `Broadcast.preprocess(dest::GPUArray, bc::Broadcasted{Nothing})` which allows for a
49+
# complete transformation based on the output type just at the end of the pipeline.
50+
# - `Broadcast.broadcasted(::Style, f)` selection of an implementation of `f` compatible
51+
# with `Style`
3852
#
39-
# This Broadcast can be further customised by:
40-
# - `Broadcast.preprocess(dest::GPUArray, bc::Broadcasted{Nothing})` which allows for a complete transformation
41-
# Broadcasted based on the output type just at the end of the pipeline.
42-
# - `Broadcast.broadcasted(::Style, f)` selection of an implementation of `f` compatible with `Style`
4353
# For more information see the Base documentation.
4454
@inline function Base.copyto!(dest::GPUDestArray, bc::Broadcasted{Nothing})
4555
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
@@ -53,11 +63,10 @@ end
5363
return dest
5464
end
5565

56-
# Base defines this method as a performance optimization, but we don't know how
57-
# to do `fill!` in general for all `GPUDestArray` so we just straight go to the fallback
58-
@inline function Base.copyto!(dest::GPUDestArray, bc::Broadcasted{<:Broadcast.AbstractArrayStyle{0}})
59-
return copyto!(dest, convert(Broadcasted{Nothing}, bc))
60-
end
66+
# Base defines this method as a performance optimization, but we don't know how to do
67+
# `fill!` in general for all `GPUDestArray` so we just go straight to the fallback
68+
@inline Base.copyto!(dest::GPUDestArray, bc::Broadcasted{<:Broadcast.AbstractArrayStyle{0}}) =
69+
copyto!(dest, convert(Broadcasted{Nothing}, bc))
6170

6271
# TODO: is this still necessary?
6372
function mapidx(f, A::GPUArray, args::NTuple{N, Any}) where N

0 commit comments

Comments
 (0)