Skip to content

Commit fd591b7

Browse files
authored
Merge pull request #172 from JuliaGPU/tb/wrappers
Use Adapt.jl to generate array wrapper methods
2 parents 6046c73 + 087b9fc commit fd591b7

File tree

7 files changed

+46
-62
lines changed

7 files changed

+46
-62
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ name = "GPUArrays"
22
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
33

44
[deps]
5+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
56
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
67
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
78
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ julia 1.0
22
StaticArrays
33
FFTW
44
FillArrays 0.3
5+
Adapt 0.4.1

src/GPUArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ using Base.Cartesian
1414

1515
using FFTW
1616

17+
using Adapt
18+
1719
include("abstractarray.jl")
1820
include("abstract_gpu_interface.jl")
1921
include("ondevice.jl")

src/abstractarray.jl

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -52,29 +52,19 @@ end
5252

5353
## showing
5454

55-
for (AT, f) in
56-
(GPUArray => Array,
57-
SubArray{<:Any,<:Any,<:GPUArray} => x->SubArray(Array(parent(x)), parentindices(x)),
58-
LinearAlgebra.Adjoint{<:Any,<:GPUArray} => x->LinearAlgebra.adjoint(Array(parent(x))),
59-
LinearAlgebra.Transpose{<:Any,<:GPUArray} => x->LinearAlgebra.transpose(Array(parent(x))),
60-
LinearAlgebra.LowerTriangular{<:Any,<:GPUArray} => x->LinearAlgebra.LowerTriangular(Array(x.data)),
61-
LinearAlgebra.UnitLowerTriangular{<:Any,<:GPUArray} => x->LinearAlgebra.UnitLowerTriangular(Array(x.data)),
62-
LinearAlgebra.UpperTriangular{<:Any,<:GPUArray} => x->LinearAlgebra.UpperTriangular(Array(x.data)),
63-
LinearAlgebra.UnitUpperTriangular{<:Any,<:GPUArray} => x->LinearAlgebra.UnitUpperTriangular(Array(x.data))
64-
)
65-
@eval begin
66-
# for display
67-
Base.print_array(io::IO, X::$AT) =
68-
Base.print_array(io,$f(X))
69-
70-
# for show
71-
Base._show_nonempty(io::IO, X::$AT, prefix::String) =
72-
Base._show_nonempty(io,$f(X),prefix)
73-
Base._show_empty(io::IO, X::$AT) =
74-
Base._show_empty(io,$f(X))
75-
Base.show_vector(io::IO, v::$AT, args...) =
76-
Base.show_vector(io,$f(v),args...)
77-
end
55+
for (W, ctor) in (:AT => (A,mut)->mut(A), Adapt.wrappers...)
56+
@eval begin
57+
# display
58+
Base.print_array(io::IO, X::$W where {AT <: GPUArray}) = Base.print_array(io, $ctor(X, Array))
59+
60+
# show
61+
Base._show_nonempty(io::IO, X::$W where {AT <: GPUArray}, prefix::String) =
62+
Base._show_nonempty(io, $ctor(X, Array), prefix)
63+
Base._show_empty(io::IO, X::$W where {AT <: GPUArray}) =
64+
Base._show_empty(io, $ctor(X, Array))
65+
Base.show_vector(io::IO, v::$W where {AT <: GPUArray}, args...) =
66+
Base.show_vector(io, $ctor(v, Array), args...)
67+
end
7868
end
7969

8070
# memory operations

src/array.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ end
145145
function LocalMemory(state::JLState, ::Type{T}, ::Val{N}, ::Val{C}) where {T, N, C}
146146
state.localmem_counter += 1
147147
lmems = state.localmems[blockidx_x(state)]
148-
# first invokation in block
148+
# first invocation in block
149149
if length(lmems) < state.localmem_counter
150150
lmem = fill(zero(T), N)
151151
push!(lmems, lmem)

src/broadcast.jl

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,21 @@ import Base.Broadcast: BroadcastStyle, Broadcasted, ArrayStyle
1212
# instead of using `ArrayStyle{GPUArray}`, due to the fact how `similar` works.
1313
BroadcastStyle(::Type{T}) where {T<:GPUArray} = ArrayStyle{T}()
1414

15-
# These wrapper types otherwise forget that they are GPU compatible
15+
# Wrapper types otherwise forget that they are GPU compatible
1616
#
1717
# NOTE: Don't directly use ArrayStyle{GPUArray} here since that would mean that `CuArrays`
1818
# customization no longer take effect.
19-
BroadcastStyle(::Type{<:LinearAlgebra.Transpose{<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
20-
BroadcastStyle(::Type{<:LinearAlgebra.Adjoint{<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
21-
BroadcastStyle(::Type{<:SubArray{<:Any,<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
22-
23-
backend(::Type{<:LinearAlgebra.Transpose{<:Any,T}}) where {T<:GPUArray} = backend(T)
24-
backend(::Type{<:LinearAlgebra.Adjoint{<:Any,T}}) where {T<:GPUArray} = backend(T)
25-
backend(::Type{<:SubArray{<:Any,<:Any,T}}) where {T<:GPUArray} = backend(T)
19+
for (W, ctor) in Adapt.wrappers
20+
@eval begin
21+
BroadcastStyle(::Type{<:$W}) where {AT<:GPUArray} = BroadcastStyle(AT)
22+
backend(::Type{<:$W}) where {AT<:GPUArray} = backend(AT)
23+
end
24+
end
2625

2726
# This Union is a hack. Ideally Base would have a Transpose <: WrappedArray <: AbstractArray
2827
# and we could define our methods in terms of Union{GPUArray, WrappedArray{<:Any, <:GPUArray}}
29-
const GPUDestArray = Union{GPUArray,
30-
LinearAlgebra.Transpose{<:Any,<:GPUArray},
31-
LinearAlgebra.Adjoint{<:Any,<:GPUArray},
32-
SubArray{<:Any,<:Any,<:GPUArray}}
28+
@eval const GPUDestArray =
29+
Union{GPUArray, $((:($W where {AT <: GPUArray}) for (W, _) in Adapt.wrappers)...)}
3330

3431
# We purposefully only specialize `copyto!`, dependent packages need to make sure that they
3532
# can handle:

test/runtests.jl

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,29 @@
1-
using GPUArrays, Test
2-
using GPUArrays.TestSuite
3-
1+
# GPUArrays development often happens in lockstep with other packages, so try to match branches
42
using Pkg
5-
6-
@testset "JLArray" begin
7-
GPUArrays.test(JLArray)
3+
function match_package(package, var)
4+
try
5+
branch = ENV[var]
6+
Pkg.add(PackageSpec(name=package, rev=branch))
7+
@info "Installed $package from branch $branch"
8+
catch ex
9+
@warn "Could not install $package from $branch branch, trying master" exception=ex
10+
Pkg.add(PackageSpec(name=package, rev="master"))
11+
@info "Installed $package from master"
12+
end
813
end
14+
haskey(ENV, "TRAVIS") && match_package("Adapt", "TRAVIS_PULL_REQUEST_BRANCH")
15+
haskey(ENV, "APPVEYOR") && match_package("Adapt", "APPVEYOR_PULL_REQUEST_HEAD_REPO_BRANCH")
16+
haskey(ENV, "GITLAB_CI") && match_package("Adapt", "CI_COMMIT_REF_NAME")
917

10-
function test_package(package, branch=nothing)
11-
mktempdir() do devdir
12-
withenv("JULIA_PKG_DEVDIR" => devdir) do
13-
# try to install from the same branch of GPUArrays
14-
try
15-
if branch === nothing
16-
branch = chomp(read(`git -C $(@__DIR__) rev-parse --abbrev-ref HEAD`, String))
17-
branch == "HEAD" && error("in detached HEAD state")
18-
end
19-
Pkg.add(PackageSpec(name=package, rev=String(branch)))
20-
@info "Installed $package from $branch branch"
21-
catch ex
22-
@warn "Could not install $package from same branch as GPUArrays, trying master branch" exception=ex
23-
Pkg.add(PackageSpec(name=package, rev="master"))
24-
end
18+
using GPUArrays, Test
2519

26-
Pkg.test(package)
27-
end
28-
end
20+
@testset "JLArray" begin
21+
GPUArrays.test(JLArray)
2922
end
3023

3124
if haskey(ENV, "GITLAB_CI")
32-
branch = ENV["CI_COMMIT_REF_NAME"]
25+
match_package("CuArrays", "CI_COMMIT_REF_NAME")
3326
@testset "CuArray" begin
34-
test_package("CuArrays", branch)
27+
Pkg.test("CuArrays")
3528
end
3629
end

0 commit comments

Comments
 (0)