@@ -16,6 +16,12 @@ BroadcastStyle(::Type{T}) where {T<:GPUArray} = ArrayStyle{T}()
16
16
BroadcastStyle (:: Type{<:LinearAlgebra.Transpose{<:Any,T}} ) where {T<: GPUArray } = BroadcastStyle (T)
17
17
BroadcastStyle (:: Type{<:LinearAlgebra.Adjoint{<:Any,T}} ) where {T<: GPUArray } = BroadcastStyle (T)
18
18
19
+ # This Union is a hack. Ideally Base would have a Transpose <: WrappedArray <: AbstractArray
20
+ # and we could define our methods in terms of Union{GPUArray, WrappedArray{<:Any, <:GPUArray}}
21
+ const GPUDestArray = Union{GPUArray,
22
+ LinearAlgebra. Transpose{<: Any ,<: GPUArray },
23
+ LinearAlgebra. Adjoint{<: Any ,<: GPUArray }}
24
+
19
25
# This method is responsible for selection the output type of broadcast
20
26
function Base. similar (bc:: Broadcasted{<:ArrayStyle{GPU}} , :: Type{ElType} ) where {GPU <: GPUArray , ElType}
21
27
similar (GPU, ElType, axes (bc))
32
38
# Broadcasted based on the output type just at the end of the pipeline.
33
39
# - `Broadcast.broadcasted(::Style, f)` selection of an implementation of `f` compatible with `Style`
34
40
# For more information see the Base documentation.
35
- @inline function Base. copyto! (dest:: GPUArray , bc:: Broadcasted{Nothing} )
41
+ @inline function Base. copyto! (dest:: GPUDestArray , bc:: Broadcasted{Nothing} )
36
42
axes (dest) == axes (bc) || Broadcast. throwdm (axes (dest), axes (bc))
37
43
bc′ = Broadcast. preprocess (dest, bc)
38
44
gpu_call (dest, (dest, bc′)) do state, dest, bc′
44
50
return dest
45
51
end
46
52
53
+ # Base defines this method as a performance optimization, but we don't know how
54
+ # to do `fill!` in general for all `GPUDestArray` so we just straight go to the fallback
55
+ @inline function Base. copyto! (dest:: GPUDestArray , bc:: Broadcasted{<:Broadcast.AbstractArrayStyle{0}} )
56
+ return copyto! (dest, convert (Broadcasted{Nothing}, bc))
57
+ end
58
+
47
59
# TODO : is this still necessary?
48
60
function mapidx (f, A:: GPUArray , args:: NTuple{N, Any} ) where N
49
61
gpu_call (A, (f, A, args)) do state, f, A, args
0 commit comments