Skip to content

Commit f546631

Browse files
vchuravy authored and maleadt committed
introduce GPUDestArray to handle ArrayWrappers
1 parent 7f97371 commit f546631

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

src/broadcast.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ BroadcastStyle(::Type{T}) where {T<:GPUArray} = ArrayStyle{T}()
1616
BroadcastStyle(::Type{<:LinearAlgebra.Transpose{<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
1717
BroadcastStyle(::Type{<:LinearAlgebra.Adjoint{<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
1818

19+
# This Union is a hack. Ideally Base would have a Transpose <: WrappedArray <: AbstractArray
20+
# and we could define our methods in terms of Union{GPUArray, WrappedArray{<:Any, <:GPUArray}}
21+
const GPUDestArray = Union{GPUArray,
22+
LinearAlgebra.Transpose{<:Any,<:GPUArray},
23+
LinearAlgebra.Adjoint{<:Any,<:GPUArray}}
24+
1925
# This method is responsible for selecting the output type of broadcast
2026
function Base.similar(bc::Broadcasted{<:ArrayStyle{GPU}}, ::Type{ElType}) where {GPU <: GPUArray, ElType}
2127
similar(GPU, ElType, axes(bc))
@@ -32,7 +38,7 @@ end
3238
# Broadcasted based on the output type just at the end of the pipeline.
3339
# - `Broadcast.broadcasted(::Style, f)` selection of an implementation of `f` compatible with `Style`
3440
# For more information see the Base documentation.
35-
@inline function Base.copyto!(dest::GPUArray, bc::Broadcasted{Nothing})
41+
@inline function Base.copyto!(dest::GPUDestArray, bc::Broadcasted{Nothing})
3642
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
3743
bc′ = Broadcast.preprocess(dest, bc)
3844
gpu_call(dest, (dest, bc′)) do state, dest, bc′
@@ -44,6 +50,12 @@ end
4450
return dest
4551
end
4652

53+
# Base defines this method as a performance optimization, but we don't know how
54+
# to do `fill!` in general for all `GPUDestArray`, so we just go straight to the fallback
55+
@inline function Base.copyto!(dest::GPUDestArray, bc::Broadcasted{<:Broadcast.AbstractArrayStyle{0}})
56+
return copyto!(dest, convert(Broadcasted{Nothing}, bc))
57+
end
58+
4759
# TODO: is this still necessary?
4860
function mapidx(f, A::GPUArray, args::NTuple{N, Any}) where N
4961
gpu_call(A, (f, A, args)) do state, f, A, args

0 commit comments

Comments
 (0)