Skip to content

Commit 0ed6336

Browse files
authored
Merge pull request #139 from JuliaGPU/vc/cleanup
Broadcast cleanup
2 parents 0479adc + 07c2ec8 commit 0ed6336

File tree

4 files changed

+7
-42
lines changed

4 files changed

+7
-42
lines changed

REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
julia 0.7-alpha
22
StaticArrays
33
FFTW
4+
FillArrays

src/abstractarray.jl

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import Base: show, showcompact, similar, convert, _reshape, map!, copyto!, map, copy, deepcopy
1+
import Base: similar, convert, _reshape, map!, copyto!, map, copy, deepcopy
22

33
# Dense GPU Array
44
abstract type GPUArray{T, N} <: DenseArray{T, N} end
@@ -16,16 +16,6 @@ struct LocalMemory{T} <: GPUArray{T, 1}
1616
LocalMemory{T}(x::Integer) where T = new{T}(x)
1717
end
1818

19-
#=
20-
AbstractArray interface
21-
=#
22-
23-
function Base.show(io::IO, A::GPUArray)
24-
print(io, "GPU: ")
25-
Base.show(io, Array(A), repr)
26-
end
27-
28-
2919
############################################
3020
# serialization
3121
import Serialization: AbstractSerializer, serialize, deserialize, serialize_type

src/array.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,9 @@ pointer(x::JLArray) = pointer(x.data)
3030

3131

3232
## I/O
33-
34-
Base.show(io::IO, x::JLArray) = show(io, collect(x))
35-
Base.show(io::IO, x::LinearAlgebra.Adjoint{<:Any,<:JLArray}) = show(io, LinearAlgebra.adjoint(collect(x.parent)))
36-
Base.show(io::IO, x::LinearAlgebra.Transpose{<:Any,<:JLArray}) = show(io, LinearAlgebra.transpose(collect(x.parent)))
37-
38-
Base.show(io::IO, ::MIME"text/plain", x::JLArray) = show(io, MIME"text/plain"(), collect(x))
39-
Base.show(io::IO, ::MIME"text/plain", x::LinearAlgebra.Adjoint{<:Any,<:JLArray}) = show(io, MIME"text/plain"(), LinearAlgebra.adjoint(collect(x.parent)))
40-
Base.show(io::IO, ::MIME"text/plain", x::LinearAlgebra.Transpose{<:Any,<:JLArray}) = show(io, MIME"text/plain"(), LinearAlgebra.transpose(collect(x.parent)))
41-
33+
Base.print_array(io::IO, x::GPUArray) = Base.print_array(io, collect(x))B
34+
Base.print_array(io::IO, x::LinearAlgebra.Adjoint{<:Any,<:GPUArray}) = Base.print_array(io, LinearAlgebra.adjoint(collect(x.parent)))
35+
Base.print_array(io::IO, x::LinearAlgebra.Transpose{<:Any,<:GPUArray}) = Base.print_array(io, LinearAlgebra.transpose(collect(x.parent)))
4236

4337
## other
4438

src/broadcast.jl

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,13 @@
11
using Base.Broadcast
2-
import Base.Broadcast: Broadcasted
32

4-
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, broadcast_axes
5-
import Base.Broadcast: DefaultArrayStyle, materialize!, flatten, ArrayStyle, combine_styles
3+
import Base.Broadcast: BroadcastStyle, Broadcasted, ArrayStyle
64

75
BroadcastStyle(::Type{T}) where T <: GPUArray = ArrayStyle{T}()
8-
BroadcastStyle(::Type{Any}, ::Type{T}) where T <: GPUArray = ArrayStyle{T}()
9-
BroadcastStyle(::Type{T}, ::Type{Any}) where T <: GPUArray = ArrayStyle{T}()
10-
BroadcastStyle(::Type{T1}, ::Type{T2}) where {T1 <: GPUArray, T2 <: GPUArray} = ArrayStyle{T}()
116

12-
const GPUBroadcast = Broadcasted{<: ArrayStyle{<: GPUArray}}
13-
14-
function Base.similar(bc::Broadcasted{ArrayStyle{GPU}}, ::Type{ElType}) where {GPU <: GPUArray, ElType}
7+
function Base.similar(bc::Broadcasted{<:ArrayStyle{GPU}}, ::Type{ElType}) where {GPU <: GPUArray, ElType}
158
similar(GPU, ElType, axes(bc))
169
end
1710

18-
@inline function Base.copyto!(dest::GPUArray, bc::GPUBroadcast)
19-
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
20-
bc′ = Broadcast.preprocess(dest, bc)
21-
gpu_call(dest, (dest, bc′)) do state, dest, bc′
22-
let I = CartesianIndex(@cartesianidx(dest))
23-
@inbounds dest[I] = bc′[I]
24-
end
25-
end
26-
27-
return dest
28-
end
29-
30-
# the same?
3111
@inline function Base.copyto!(dest::GPUArray, bc::Broadcasted{Nothing})
3212
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
3313
bc′ = Broadcast.preprocess(dest, bc)

0 commit comments

Comments
 (0)