Skip to content

Commit c835047

Browse files
Merge pull request #196 from sudo-rushil/sr/batched
BatchedAdjOrTrans interface tests
2 parents a0fbad5 + dcfdd18 commit c835047

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

test/batchedmul.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,65 @@ end
5959
@test_throws Exception batched_mul!(zeros(2,2,10), rand(2,2,2), rand(TB, 2,2,2))
6060

6161
end
62+
63+
@testset "BatchedAdjOrTrans interface * $TB" for TB in [Float64, Float32]
64+
A = randn(7,5,3)
65+
B = randn(TB, 5,7,3)
66+
C = randn(7,6,3)
67+
68+
function interface_tests(X, _X)
69+
@test length(_X) == length(X)
70+
@test size(_X) == (size(X, 2), size(X, 1), size(X, 3))
71+
@test axes(_X) == (axes(X, 2), axes(X, 1), axes(X, 3))
72+
#
73+
@test getindex(_X, 2, 3, 3) == getindex(X, 3, 2, 3)
74+
@test getindex(_X, 5, 4, 1) == getindex(X, 4, 5, 1)
75+
#
76+
setindex!(_X, 2.0, 2, 4, 1)
77+
@test getindex(_X, 2, 4, 1) == 2.0
78+
setindex!(_X, 3.0, 1, 2, 2)
79+
@test getindex(_X, 1, 2, 2) == 3.0
80+
81+
_sim = similar(_X, TB, (2, 3))
82+
@test size(_sim) == (2, 3)
83+
@test typeof(_sim) == Array{TB, 2}
84+
85+
_sim = similar(_X, TB)
86+
@test length(_sim) == length(_X)
87+
@test typeof(_sim) == Array{TB, 3}
88+
89+
_sim = similar(_X, (2, 3))
90+
@test size(_sim) == (2, 3)
91+
@test typeof(_sim) == Array{Float64, 2}
92+
93+
_sim = similar(_X)
94+
@test length(_sim) == length(_X)
95+
@test typeof(_sim) == Array{Float64, 3}
96+
97+
@test parent(_X) == _X.parent
98+
end
99+
100+
for (X, _X) in zip([A, B, C], map(batched_adjoint, [A, B, C]))
101+
interface_tests(X, _X)
102+
103+
@test -_X == NNlib.BatchedAdjoint(-_X.parent)
104+
105+
_copyX = copy(_X)
106+
@test _X == _copyX
107+
108+
setindex!(_copyX, 2.0, 1, 2, 1)
109+
@test _X != _copyX
110+
end
111+
112+
for (X, _X) in zip([A, B, C], map(batched_transpose, [A, B, C]))
113+
interface_tests(X, _X)
114+
115+
@test -_X == NNlib.BatchedTranspose(-_X.parent)
116+
117+
_copyX = copy(_X)
118+
@test _X == _copyX
119+
120+
setindex!(_copyX, 2.0, 1, 2, 1)
121+
@test _X != _copyX
122+
end
123+
end

0 commit comments

Comments
 (0)