Skip to content

Commit 4c54dd6

Browse files
committed
Fix generic matmul with 1D inputs.
1 parent 4d99813 commit 4c54dd6

File tree

2 files changed

+27
-31
lines changed

2 files changed

+27
-31
lines changed

src/host/linalg.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ end
6666

6767
## matrix multiplication
6868

69-
function generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}) where {T,S,R}
69+
function generic_matmatmul!(C::AbstractGPUVecOrMat{R}, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R}
7070
if size(A,2) != size(B,1)
7171
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
7272
end
@@ -77,6 +77,11 @@ function generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::A
7777
return fill!(C, zero(R))
7878
end
7979

80+
# reshape vectors to matrices
81+
A = reshape(A, (size(A,1), size(A,2)))
82+
B = reshape(B, (size(B,1), size(B,2)))
83+
C = reshape(C, (size(C,1), size(C,2)))
84+
8085
gpu_call(C, A, B) do ctx, C, A, B
8186
idx = @linearidx C
8287
i, j = Tuple(CartesianIndices(C)[idx])

test/testsuite/linalg.jl

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -62,36 +62,27 @@ function test_linalg(AT)
6262
end
6363

6464
@testset "matrix multiplication" begin
65-
a = rand(Int8, 3, 3)
66-
b = rand(Int8, 3, 3)
67-
d_a = AT{Int8}(a)
68-
d_b = AT{Int8}(b)
69-
d_c = d_a*d_b
70-
@test collect(d_c) == a*b
71-
a = rand(Complex{Int8}, 3, 3)
72-
b = rand(Complex{Int8}, 3, 3)
73-
d_a = AT{Complex{Int8}}(a)
74-
d_b = AT{Complex{Int8}}(b)
75-
d_c = d_a'*d_b
76-
@test collect(d_c) == a'*b
77-
d_c = d_a*d_b'
78-
@test collect(d_c) == a*b'
79-
d_c = d_a'*d_b'
80-
@test collect(d_c) == a'*b'
81-
d_c = transpose(d_a)*d_b'
82-
@test collect(d_c) == transpose(a)*b'
83-
d_c = d_a'*transpose(d_b)
84-
@test collect(d_c) == a'*transpose(b)
85-
d_c = transpose(d_a)*d_b
86-
@test collect(d_c) == transpose(a)*b
87-
d_c = d_a*transpose(d_b)
88-
@test collect(d_c) == a*transpose(b)
89-
d_c = transpose(d_a)*transpose(d_b)
90-
@test collect(d_c) == transpose(a)*transpose(b)
91-
d_c = rmul!(copy(d_a), Complex{Int8}(2, 2))
92-
@test collect(d_c) == a*Complex{Int8}(2, 2)
93-
d_c = lmul!(Complex{Int8}(2, 2), copy(d_a))
94-
@test collect(d_c) == Complex{Int8}(2, 2)*a
65+
for (a,b) in [((3,4),(4,3)), ((3,), (1,3)), ((1,3), (3))], T in supported_eltypes()
66+
@test compare(*, AT, rand(T, a), rand(T, b))
67+
68+
if length(a) > 1
69+
@test compare(*, AT, transpose(rand(T, reverse(a))), rand(T, b))
70+
@test compare(*, AT, adjoint(rand(T, reverse(a))), rand(T, b))
71+
end
72+
73+
if length(b) > 1
74+
@test compare(*, AT, rand(T, a), transpose(rand(T, reverse(b))))
75+
@test compare(*, AT, rand(T, a), adjoint(rand(T, reverse(b))))
76+
end
77+
78+
if length(a) > 1 && length(b) > 1
79+
@test compare(*, AT, transpose(rand(T, reverse(a))), transpose(rand(T, reverse(b))))
80+
@test compare(*, AT, adjoint(rand(T, reverse(a))), adjoint(rand(T, reverse(b))))
81+
end
82+
83+
@test compare(rmul!, AT, rand(T, a), Ref(rand(T)))
84+
@test compare(lmul!, AT, Ref(rand(T)), rand(T, b))
85+
end
9586
end
9687
end
9788
end

0 commit comments

Comments
 (0)