Skip to content

Commit 876b45c

Browse files
committed
Improve the tests to cover 3-arg and 5-arg mul!
1 parent c3b3a22 commit 876b45c

File tree

1 file changed

+13
-20
lines changed

1 file changed

+13
-20
lines changed

test/testsuite/linalg.jl

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -61,28 +61,21 @@ function test_linalg(AT)
6161
@test f(A, d) == Array(f!(AT(A), d))
6262
end
6363

64-
@testset "matrix multiplication" begin
65-
for (a,b) in [((3,4),(4,3)), ((3,), (1,3)), ((1,3), (3))], T in supported_eltypes()
66-
@test compare(*, AT, rand(T, a), rand(T, b))
67-
68-
if length(a) > 1
69-
@test compare(*, AT, transpose(rand(T, reverse(a))), rand(T, b))
70-
@test compare(*, AT, adjoint(rand(T, reverse(a))), rand(T, b))
71-
end
72-
73-
if length(b) > 1
74-
@test compare(*, AT, rand(T, a), transpose(rand(T, reverse(b))))
75-
@test compare(*, AT, rand(T, a), adjoint(rand(T, reverse(b))))
76-
end
64+
@testset "$T gemv y := $f(A) * x * a + y * b" for f in (identity, transpose, adjoint), T in supported_eltypes()
65+
@test compare(*, AT, f(rand(T, 4, 4)), rand(T, 4))
66+
@test compare(mul!, AT, rand(T, 4), f(rand(T, 4, 4)), rand(T, 4))
67+
@test compare(mul!, AT, rand(T, 4), f(rand(T, 4, 4)), rand(T, 4), T(4), T(5))
68+
end
7769

78-
if length(a) > 1 && length(b) > 1
79-
@test compare(*, AT, transpose(rand(T, reverse(a))), transpose(rand(T, reverse(b))))
80-
@test compare(*, AT, adjoint(rand(T, reverse(a))), adjoint(rand(T, reverse(b))))
81-
end
70+
@testset "$T gemm C := $f(A) * $g(B) * a + C * b" for f in (identity, transpose, adjoint), g in (identity, transpose, adjoint), T in supported_eltypes()
71+
@test compare(*, AT, f(rand(T, 4, 4)), g(rand(T, 4, 4)))
72+
@test compare(mul!, AT, rand(T, 4, 4), f(rand(T, 4, 4)), g(rand(T, 4, 4)))
73+
@test compare(mul!, AT, rand(T, 4, 4), f(rand(T, 4, 4)), g(rand(T, 4, 4)), T(4), T(5))
74+
end
8275

83-
@test compare(rmul!, AT, rand(T, a), Ref(rand(T)))
84-
@test compare(lmul!, AT, Ref(rand(T)), rand(T, b))
85-
end
76+
@testset "lmul! and rmul!" for (a,b) in [((3,4),(4,3)), ((3,), (1,3)), ((1,3), (3))], T in supported_eltypes()
77+
@test compare(rmul!, AT, rand(T, a), Ref(rand(T)))
78+
@test compare(lmul!, AT, Ref(rand(T)), rand(T, b))
8679
end
8780
end
8881
end

0 commit comments

Comments
 (0)