Skip to content

Commit 43746d7

Browse files
authored
Merge pull request #214 from JuliaGPU/tb/adapt_collect
Use Adapt.jl for generating collect methods.
2 parents 54de102 + 42a4d3b commit 43746d7

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

src/abstractarray.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,23 +50,36 @@ function to_cartesian(A, indices::Tuple)
5050
CartesianIndices(start, stop)
5151
end
5252

53-
## showing
53+
## convert to CPU (keeping wrapper type)
5454

5555
Adapt.adapt_storage(::Type{<:Array}, xs::AbstractArray) = convert(Array, xs)
56-
cpu(xs) = adapt(Array, xs)
56+
convert_to_cpu(xs) = adapt(Array, xs)
57+
58+
## showing
5759

5860
for (W, ctor) in (:AT => (A,mut)->mut(A), Adapt.wrappers...)
5961
@eval begin
6062
# display
61-
Base.print_array(io::IO, X::$W where {AT <: GPUArray}) = Base.print_array(io, $ctor(X, cpu))
63+
Base.print_array(io::IO, X::$W where {AT <: GPUArray}) =
64+
Base.print_array(io, $ctor(X, convert_to_cpu))
6265

6366
# show
6467
Base._show_nonempty(io::IO, X::$W where {AT <: GPUArray}, prefix::String) =
65-
Base._show_nonempty(io, $ctor(X, cpu), prefix)
68+
Base._show_nonempty(io, $ctor(X, convert_to_cpu), prefix)
6669
Base._show_empty(io::IO, X::$W where {AT <: GPUArray}) =
67-
Base._show_empty(io, $ctor(X, cpu))
70+
Base._show_empty(io, $ctor(X, convert_to_cpu))
6871
Base.show_vector(io::IO, v::$W where {AT <: GPUArray}, args...) =
69-
Base.show_vector(io, $ctor(v, cpu), args...)
72+
Base.show_vector(io, $ctor(v, convert_to_cpu), args...)
73+
end
74+
end
75+
76+
## collect to CPU (discarding wrapper type)
77+
78+
collect_to_cpu(xs::AbstractArray) = collect(convert_to_cpu(xs))
79+
80+
for (W, ctor) in (:AT => (A,mut)->mut(A), Adapt.wrappers...)
81+
@eval begin
82+
Base.collect(X::$W where {AT <: GPUArray}) = collect_to_cpu(X)
7083
end
7184
end
7285

src/mapreduce.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -166,18 +166,11 @@ for i = 0:10
166166

167167
end
168168

169-
to_cpu(x) = x
170-
to_cpu(x::GPUArray) = Array(x)
171-
to_cpu(x::Broadcasted{ArrayStyle{AT}}) where {AT <: GPUArray} = to_cpu(Base.Broadcast.materialize(x))
172-
to_cpu(x::LinearAlgebra.Transpose) = LinearAlgebra.Transpose(to_cpu(parent(x)))
173-
to_cpu(x::LinearAlgebra.Adjoint) = LinearAlgebra.Adjoint(to_cpu(parent(x)))
174-
to_cpu(x::SubArray) = SubArray(to_cpu(parent(x)), parentindices(x))
175-
176169
function acc_mapreduce(f, op, v0::OT, A::GPUSrcArray, rest::Tuple) where {OT}
177170
blocksize = 80
178171
threads = 256
179172
if length(A) <= blocksize * threads
180-
args = zip(to_cpu(A), to_cpu.(rest)...)
173+
args = zip(convert_to_cpu(A), convert_to_cpu.(rest)...)
181174
return mapreduce(x-> f(x...), op, args, init = v0)
182175
end
183176
out = similar(A, OT, (blocksize,))

0 commit comments

Comments
 (0)