@@ -23,29 +23,29 @@ function bmm_adjtest(a,b; adjA = false, adjB = false)
23
23
end
24
24
25
25
26
- @testset " Batched Matrix Multiplication " for TB in [Float64, Float32]
26
+ @testset " Batched Matrices: Float64 * $TB " for TB in [Float64, Float32]
27
27
28
28
A = randn (7 ,5 ,3 )
29
29
B = randn (TB, 5 ,7 ,3 )
30
30
C = randn (7 ,6 ,3 )
31
31
32
32
@test batched_mul (A, B) == bmm_test (A, B)
33
33
@test batched_mul (batched_transpose (A), batched_transpose (B)) == bmm_test (A, B; transA = true , transB = true )
34
- @test batched_mul (batched_transpose (A), C) == bmm_test (A, C; transA = true )
35
- @test batched_mul (A, batched_transpose (A)) == bmm_test (A, A; transB = true )
34
+ @test batched_mul (batched_transpose (A), C) ≈ bmm_test (A, C; transA = true )
35
+ @test batched_mul (A, batched_transpose (A)) ≈ bmm_test (A, A; transB = true )
36
36
37
37
38
38
cA = randn (Complex{Float64}, 7 ,5 ,3 )
39
39
cB = randn (Complex{TB}, 5 ,7 ,3 )
40
40
cC = randn (Complex{Float64}, 7 ,6 ,3 )
41
41
42
42
@test batched_mul (cA, cB) == bmm_adjtest (cA, cB)
43
- @test batched_mul (batched_adjoint (cA), batched_adjoint (cB)) == bmm_adjtest (cA, cB; adjA = true , adjB = true )
44
- @test batched_mul (batched_adjoint (cA), cC) == bmm_adjtest (cA, cC; adjA = true )
45
- @test batched_mul (cA, batched_adjoint (cA)) == bmm_adjtest (cA, cA; adjB = true )
43
+ @test batched_mul (batched_adjoint (cA), batched_adjoint (cB)) ≈ bmm_adjtest (cA, cB; adjA = true , adjB = true )
44
+ @test batched_mul (batched_adjoint (cA), cC) ≈ bmm_adjtest (cA, cC; adjA = true )
45
+ @test batched_mul (cA, batched_adjoint (cA)) ≈ bmm_adjtest (cA, cA; adjB = true )
46
46
47
- @test batched_transpose (batched_transpose (A)) == A
48
- @test batched_adjoint (batched_adjoint (cA)) == cA
47
+ @test batched_transpose (batched_transpose (A)) === A
48
+ @test batched_adjoint (batched_adjoint (cA)) === cA
49
49
50
50
TBi = TB== Float64 ? Int64 : Int32
51
51
iA = rand (1 : 99 , 7 ,5 ,3 )
0 commit comments