|
59 | 59 | @test_throws Exception batched_mul!(zeros(2,2,10), rand(2,2,2), rand(TB, 2,2,2))
|
60 | 60 |
|
61 | 61 | end
|
| 62 | + |
| 63 | +@testset "BatchedAdjOrTrans interface * $TB" for TB in [Float64, Float32] |
| 64 | + A = randn(7,5,3) |
| 65 | + B = randn(TB, 5,7,3) |
| 66 | + C = randn(7,6,3) |
| 67 | + |
| 68 | + function interface_tests(X, _X) |
| 69 | + @test length(_X) == length(X) |
| 70 | + @test size(_X) == (size(X, 2), size(X, 1), size(X, 3)) |
| 71 | + @test axes(_X) == (axes(X, 2), axes(X, 1), axes(X, 3)) |
| 72 | + # |
| 73 | + @test getindex(_X, 2, 3, 3) == getindex(X, 3, 2, 3) |
| 74 | + @test getindex(_X, 5, 4, 1) == getindex(X, 4, 5, 1) |
| 75 | + # |
| 76 | + setindex!(_X, 2.0, 2, 4, 1) |
| 77 | + @test getindex(_X, 2, 4, 1) == 2.0 |
| 78 | + setindex!(_X, 3.0, 1, 2, 2) |
| 79 | + @test getindex(_X, 1, 2, 2) == 3.0 |
| 80 | + |
| 81 | + _sim = similar(_X, TB, (2, 3)) |
| 82 | + @test size(_sim) == (2, 3) |
| 83 | + @test typeof(_sim) == Array{TB, 2} |
| 84 | + |
| 85 | + _sim = similar(_X, TB) |
| 86 | + @test length(_sim) == length(_X) |
| 87 | + @test typeof(_sim) == Array{TB, 3} |
| 88 | + |
| 89 | + _sim = similar(_X, (2, 3)) |
| 90 | + @test size(_sim) == (2, 3) |
| 91 | + @test typeof(_sim) == Array{Float64, 2} |
| 92 | + |
| 93 | + _sim = similar(_X) |
| 94 | + @test length(_sim) == length(_X) |
| 95 | + @test typeof(_sim) == Array{Float64, 3} |
| 96 | + |
| 97 | + @test parent(_X) == _X.parent |
| 98 | + end |
| 99 | + |
| 100 | + for (X, _X) in zip([A, B, C], map(batched_adjoint, [A, B, C])) |
| 101 | + interface_tests(X, _X) |
| 102 | + |
| 103 | + @test -_X == NNlib.BatchedAdjoint(-_X.parent) |
| 104 | + |
| 105 | + _copyX = copy(_X) |
| 106 | + @test _X == _copyX |
| 107 | + |
| 108 | + setindex!(_copyX, 2.0, 1, 2, 1) |
| 109 | + @test _X != _copyX |
| 110 | + end |
| 111 | + |
| 112 | + for (X, _X) in zip([A, B, C], map(batched_transpose, [A, B, C])) |
| 113 | + interface_tests(X, _X) |
| 114 | + |
| 115 | + @test -_X == NNlib.BatchedTranspose(-_X.parent) |
| 116 | + |
| 117 | + _copyX = copy(_X) |
| 118 | + @test _X == _copyX |
| 119 | + |
| 120 | + setindex!(_copyX, 2.0, 1, 2, 1) |
| 121 | + @test _X != _copyX |
| 122 | + end |
| 123 | +end |
0 commit comments