Skip to content

Commit 7953c92

Browse files
print/convert batchedadjtrans over cuarray (#54)
* print/convert batchedadjtrans over cuarray * Update test/batchedadjtrans.jl Co-authored-by: Brian Chen <[email protected]> Co-authored-by: Brian Chen <[email protected]>
1 parent d29ab6e commit 7953c92

File tree

5 files changed

+65
-0
lines changed

5 files changed

+65
-0
lines changed

ext/NNlibCUDA/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
33
version = "0.2.3"
44

55
[deps]
6+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
67
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
78
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
89
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
910
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1011
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1112

1213
[compat]
14+
Adapt = "3.3"
1315
CUDA = "3.11"
1416
NNlib = "0.8.7"
1517
julia = "1.6"

ext/NNlibCUDA/src/NNlibCUDA.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N}
99
include("upsample.jl")
1010
include("sampling.jl")
1111
include("activations.jl")
12+
include("batchedadjtrans.jl")
1213
include("batchedmul.jl")
1314
include("scatter.jl")
1415
include("gather.jl")

ext/NNlibCUDA/src/batchedadjtrans.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans
2+
using Adapt
3+
using Adapt: WrappedArray
4+
5+
const CuBatchedAdjoint{T} = BatchedAdjoint{T, <: CuArray{T}}
6+
const CuBatchedTranspose{T} = BatchedTranspose{T, <: CuArray{T}}
7+
const CuBatchedAdjOrTrans{T} = Union{CuBatchedAdjoint{T}, CuBatchedTranspose{T}}
8+
const WrappedCuBatchedAdjOrTrans{T, N} = WrappedArray{T, N, CuBatchedAdjOrTrans{T}, CuBatchedAdjOrTrans{T}}
9+
10+
11+
Base.print_array(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) = Base.print_array(io, adapt(Array, b))
12+
Base._show_nonempty(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}, prefix::String) = Base._show_nonempty(io, adapt(Array, b), prefix)
13+
Base.show_vector(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}, opn, cls) = Base.show_vector(io, adapt(Array, b), opn, cls)
14+
15+
Base.convert(::Type{T}, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) where {T<:Array} = Base.convert(T, adapt(Array, b))
16+
Base.Array{T, N}(b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) where {T, N} = Array{T, N}(adapt(Array, b))
17+
Base.collect(b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) = collect(adapt(Array, b))

ext/NNlibCUDA/test/batchedadjtrans.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
function print_array_strs(x)
2+
str = sprint((io, x)->show(io, MIME"text/plain"(), x), x)
3+
return @view split(str, '\n')[2:end]
4+
end
5+
6+
@testset "BatchedAdjOrTrans" begin
7+
x = randn(Float32, 3,4,2)
8+
y = cu(x)
9+
10+
bax = batched_adjoint(x)
11+
btx = batched_transpose(x)
12+
bay = batched_adjoint(y)
13+
bty = batched_transpose(y)
14+
15+
@test sprint(show, bax) == sprint(show, bay)
16+
@test sprint(show, btx) == sprint(show, bty)
17+
18+
@test print_array_strs(bax) == print_array_strs(bay)
19+
@test print_array_strs(btx) == print_array_strs(bty)
20+
21+
@test Array(bax) == Array(bay)
22+
@test collect(bax) == collect(bay)
23+
@test Array(btx) == Array(bty)
24+
@test collect(btx) == collect(bty)
25+
26+
for shape in (:, (12, 2))
27+
rbax = reshape(bax, shape)
28+
rbtx = reshape(btx, shape)
29+
rbay = reshape(bay, shape)
30+
rbty = reshape(bty, shape)
31+
32+
@test sprint(show, rbax) == sprint(show, rbay)
33+
@test sprint(show, rbtx) == sprint(show, rbty)
34+
35+
@test print_array_strs(rbax) == print_array_strs(rbay)
36+
@test print_array_strs(rbtx) == print_array_strs(rbty)
37+
38+
@test Array(rbax) == Array(rbay)
39+
@test collect(rbax) == collect(rbay)
40+
@test Array(rbtx) == Array(rbty)
41+
@test collect(rbtx) == collect(rbty)
42+
end
43+
44+
end

ext/NNlibCUDA/test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ CUDA.allowscalar(false)
1010
@testset "NNlibCUDA" begin
1111
include("test_utils.jl")
1212
include("activations.jl")
13+
include("batchedadjtrans.jl")
1314
include("batchedmul.jl")
1415
include("upsample.jl")
1516
include("conv.jl")

0 commit comments

Comments
 (0)