Skip to content

Commit 6c1bdff

Browse files
committed
Generate methods for array wrappers using Adapt.jl.
1 parent 38ccea2 commit 6c1bdff

File tree

3 files changed

+23
-36
lines changed

3 files changed

+23
-36
lines changed

src/abstractarray.jl

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -52,29 +52,19 @@ end
5252

5353
## showing
5454

55-
for (AT, f) in
56-
(GPUArray => Array,
57-
SubArray{<:Any,<:Any,<:GPUArray} => x->SubArray(Array(parent(x)), parentindices(x)),
58-
LinearAlgebra.Adjoint{<:Any,<:GPUArray} => x->LinearAlgebra.adjoint(Array(parent(x))),
59-
LinearAlgebra.Transpose{<:Any,<:GPUArray} => x->LinearAlgebra.transpose(Array(parent(x))),
60-
LinearAlgebra.LowerTriangular{<:Any,<:GPUArray} => x->LinearAlgebra.LowerTriangular(Array(x.data)),
61-
LinearAlgebra.UnitLowerTriangular{<:Any,<:GPUArray} => x->LinearAlgebra.UnitLowerTriangular(Array(x.data)),
62-
LinearAlgebra.UpperTriangular{<:Any,<:GPUArray} => x->LinearAlgebra.UpperTriangular(Array(x.data)),
63-
LinearAlgebra.UnitUpperTriangular{<:Any,<:GPUArray} => x->LinearAlgebra.UnitUpperTriangular(Array(x.data))
64-
)
65-
@eval begin
66-
# for display
67-
Base.print_array(io::IO, X::$AT) =
68-
Base.print_array(io,$f(X))
69-
70-
# for show
71-
Base._show_nonempty(io::IO, X::$AT, prefix::String) =
72-
Base._show_nonempty(io,$f(X),prefix)
73-
Base._show_empty(io::IO, X::$AT) =
74-
Base._show_empty(io,$f(X))
75-
Base.show_vector(io::IO, v::$AT, args...) =
76-
Base.show_vector(io,$f(v),args...)
77-
end
55+
for (W, ctor) in (:AT => (A,mut)->mut(A), Adapt.wrappers...)
56+
@eval begin
57+
# display
58+
Base.print_array(io::IO, X::$W where {AT <: GPUArray}) = Base.print_array(io, $ctor(X, Array))
59+
60+
# show
61+
Base._show_nonempty(io::IO, X::$W where {AT <: GPUArray}, prefix::String) =
62+
Base._show_nonempty(io, $ctor(X, Array), prefix)
63+
Base._show_empty(io::IO, X::$W where {AT <: GPUArray}) =
64+
Base._show_empty(io, $ctor(X, Array))
65+
Base.show_vector(io::IO, v::$W where {AT <: GPUArray}, args...) =
66+
Base.show_vector(io, $ctor(v, Array), args...)
67+
end
7868
end
7969

8070
# memory operations

src/array.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ end
145145
function LocalMemory(state::JLState, ::Type{T}, ::Val{N}, ::Val{C}) where {T, N, C}
146146
state.localmem_counter += 1
147147
lmems = state.localmems[blockidx_x(state)]
148-
# first invokation in block
148+
# first invocation in block
149149
if length(lmems) < state.localmem_counter
150150
lmem = fill(zero(T), N)
151151
push!(lmems, lmem)

src/broadcast.jl

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,21 @@ import Base.Broadcast: BroadcastStyle, Broadcasted, ArrayStyle
1212
# instead of using `ArrayStyle{GPUArray}`, due to the fact how `similar` works.
1313
BroadcastStyle(::Type{T}) where {T<:GPUArray} = ArrayStyle{T}()
1414

15-
# These wrapper types otherwise forget that they are GPU compatible
15+
# Wrapper types otherwise forget that they are GPU compatible
1616
#
1717
# NOTE: Don't directly use ArrayStyle{GPUArray} here since that would mean that `CuArrays`
1818
# customization no longer take effect.
19-
BroadcastStyle(::Type{<:LinearAlgebra.Transpose{<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
20-
BroadcastStyle(::Type{<:LinearAlgebra.Adjoint{<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
21-
BroadcastStyle(::Type{<:SubArray{<:Any,<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
22-
23-
backend(::Type{<:LinearAlgebra.Transpose{<:Any,T}}) where {T<:GPUArray} = backend(T)
24-
backend(::Type{<:LinearAlgebra.Adjoint{<:Any,T}}) where {T<:GPUArray} = backend(T)
25-
backend(::Type{<:SubArray{<:Any,<:Any,T}}) where {T<:GPUArray} = backend(T)
19+
for (W, ctor) in Adapt.wrappers
20+
@eval begin
21+
BroadcastStyle(::Type{<:$W}) where {AT<:GPUArray} = BroadcastStyle(AT)
22+
backend(::Type{<:$W}) where {AT<:GPUArray} = backend(AT)
23+
end
24+
end
2625

2726
# This Union is a hack. Ideally Base would have a Transpose <: WrappedArray <: AbstractArray
2827
# and we could define our methods in terms of Union{GPUArray, WrappedArray{<:Any, <:GPUArray}}
29-
const GPUDestArray = Union{GPUArray,
30-
LinearAlgebra.Transpose{<:Any,<:GPUArray},
31-
LinearAlgebra.Adjoint{<:Any,<:GPUArray},
32-
SubArray{<:Any,<:Any,<:GPUArray}}
28+
@eval const GPUDestArray =
29+
Union{GPUArray, $((:($W where {AT <: GPUArray}) for (W, _) in Adapt.wrappers)...)}
3330

3431
# We purposefully only specialize `copyto!`, dependent packages need to make sure that they
3532
# can handle:

0 commit comments

Comments
 (0)