|
1 | 1 | using Base.Broadcast
|
2 |
| -import Base.Broadcast: Broadcasted |
3 | 2 |
|
4 |
| -import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, broadcast_axes |
5 |
| -import Base.Broadcast: DefaultArrayStyle, materialize!, flatten, ArrayStyle, combine_styles |
| 3 | +import Base.Broadcast: BroadcastStyle, Broadcasted, ArrayStyle |
6 | 4 |
|
7 | 5 | BroadcastStyle(::Type{T}) where T <: GPUArray = ArrayStyle{T}()
|
8 |
| -BroadcastStyle(::Type{Any}, ::Type{T}) where T <: GPUArray = ArrayStyle{T}() |
9 |
| -BroadcastStyle(::Type{T}, ::Type{Any}) where T <: GPUArray = ArrayStyle{T}() |
10 |
| -BroadcastStyle(::Type{T1}, ::Type{T2}) where {T1 <: GPUArray, T2 <: GPUArray} = ArrayStyle{T}() |
11 | 6 |
|
12 |
| -const GPUBroadcast = Broadcasted{<: ArrayStyle{<: GPUArray}} |
13 |
| - |
14 |
| -function Base.similar(bc::Broadcasted{ArrayStyle{GPU}}, ::Type{ElType}) where {GPU <: GPUArray, ElType} |
| 7 | +function Base.similar(bc::Broadcasted{<:ArrayStyle{GPU}}, ::Type{ElType}) where {GPU <: GPUArray, ElType} |
15 | 8 | similar(GPU, ElType, axes(bc))
|
16 | 9 | end
|
17 | 10 |
|
18 |
| -@inline function Base.copyto!(dest::GPUArray, bc::GPUBroadcast) |
19 |
| - axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) |
20 |
| - bc′ = Broadcast.preprocess(dest, bc) |
21 |
| - gpu_call(dest, (dest, bc′)) do state, dest, bc′ |
22 |
| - let I = CartesianIndex(@cartesianidx(dest)) |
23 |
| - @inbounds dest[I] = bc′[I] |
24 |
| - end |
25 |
| - end |
26 |
| - |
27 |
| - return dest |
28 |
| -end |
29 |
| - |
30 |
| -# the same? |
31 | 11 | @inline function Base.copyto!(dest::GPUArray, bc::Broadcasted{Nothing})
|
32 | 12 | axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
|
33 | 13 | bc′ = Broadcast.preprocess(dest, bc)
|
|
0 commit comments