Skip to content

Commit 8988ac4

Browse files
committed
cleanup broadcast implementation
1 parent 0479adc commit 8988ac4

File tree

1 file changed

+2
-22
lines changed

1 file changed

+2
-22
lines changed

src/broadcast.jl

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,13 @@
11
using Base.Broadcast
2-
import Base.Broadcast: Broadcasted
32

4-
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, broadcast_axes
5-
import Base.Broadcast: DefaultArrayStyle, materialize!, flatten, ArrayStyle, combine_styles
3+
import Base.Broadcast: BroadcastStyle, Broadcasted, ArrayStyle
64

75
BroadcastStyle(::Type{T}) where T <: GPUArray = ArrayStyle{T}()
8-
BroadcastStyle(::Type{Any}, ::Type{T}) where T <: GPUArray = ArrayStyle{T}()
9-
BroadcastStyle(::Type{T}, ::Type{Any}) where T <: GPUArray = ArrayStyle{T}()
10-
BroadcastStyle(::Type{T1}, ::Type{T2}) where {T1 <: GPUArray, T2 <: GPUArray} = ArrayStyle{T}()
116

12-
const GPUBroadcast = Broadcasted{<: ArrayStyle{<: GPUArray}}
13-
14-
function Base.similar(bc::Broadcasted{ArrayStyle{GPU}}, ::Type{ElType}) where {GPU <: GPUArray, ElType}
7+
function Base.similar(bc::Broadcasted{<:ArrayStyle{GPU}}, ::Type{ElType}) where {GPU <: GPUArray, ElType}
158
similar(GPU, ElType, axes(bc))
169
end
1710

18-
@inline function Base.copyto!(dest::GPUArray, bc::GPUBroadcast)
19-
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
20-
bc′ = Broadcast.preprocess(dest, bc)
21-
gpu_call(dest, (dest, bc′)) do state, dest, bc′
22-
let I = CartesianIndex(@cartesianidx(dest))
23-
@inbounds dest[I] = bc′[I]
24-
end
25-
end
26-
27-
return dest
28-
end
29-
30-
# the same?
3111
@inline function Base.copyto!(dest::GPUArray, bc::Broadcasted{Nothing})
3212
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
3313
bc′ = Broadcast.preprocess(dest, bc)

0 commit comments

Comments
 (0)