Skip to content

Commit eb03abc

Browse files
authored
Merge pull request #252 from JuliaGPU/tb/matmul_fixes
generic_matmul! fixes
2 parents 91446b6 + 4c54dd6 commit eb03abc

File tree

4 files changed

+47
-31
lines changed

4 files changed

+47
-31
lines changed

src/host/linalg.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ end
6666

6767
## matrix multiplication
6868

69-
function generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}) where {T,S,R}
69+
function generic_matmatmul!(C::AbstractGPUVecOrMat{R}, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R}
7070
if size(A,2) != size(B,1)
7171
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
7272
end
@@ -77,6 +77,11 @@ function generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::A
7777
return fill!(C, zero(R))
7878
end
7979

80+
# reshape vectors to matrices
81+
A = reshape(A, (size(A,1), size(A,2)))
82+
B = reshape(B, (size(B,1), size(B,2)))
83+
C = reshape(C, (size(C,1), size(C,2)))
84+
8085
gpu_call(C, A, B) do ctx, C, A, B
8186
idx = @linearidx C
8287
i, j = Tuple(CartesianIndices(C)[idx])

src/reference.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,17 @@ JLArray{T,N}(xs::JLArray{T,N}) where {T,N} = xs
201201

202202
Base.convert(::Type{T}, x::T) where T <: JLArray = x
203203

204+
# Reshape a JLArray to the size `dims`.
# Throws a DimensionMismatch when prod(dims) differs from length(parent).
# Returns a new JLArray wrapping reshape(parent.data, dims).
function Base._reshape(parent::JLArray, dims::Dims)
205+
n = length(parent)
206+
prod(dims) == n || throw(DimensionMismatch("parent has $n elements, which is incompatible with size $dims"))
207+
# element type is kept; the dimensionality of the result is length(dims)
return JLArray{eltype(parent),length(dims)}(reshape(parent.data, dims), dims)
208+
end
209+
# Reshape of a 1-d JLArray to a 1-d size: after the element-count check the
# parent itself is returned, no new wrapper is constructed.
# NOTE(review): presumably this method also avoids a dispatch ambiguity with
# Base's own vector-to-vector _reshape method -- confirm.
function Base._reshape(parent::JLArray{T,1}, dims::Tuple{Int}) where T
210+
n = length(parent)
211+
prod(dims) == n || throw(DimensionMismatch("parent has $n elements, which is incompatible with size $dims"))
212+
return parent
213+
end
214+
204215

205216
## broadcast
206217

test/testsuite/base.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,15 @@ function test_base(AT)
8888
@test compare((a,b) -> cat(a, b; dims=4), AT, rand(Float32, 3, 4), rand(Float32, 3, 4))
8989
end
9090

91+
# Exercises reshape on the array type under test (AT) for a 10-element vector:
# same-size vector, column matrix and row matrix targets.
# NOTE(review): `compare` is defined elsewhere; presumably it checks the AT
# result against the plain Array result -- confirm.
@testset "reshape" begin
92+
@test compare(reshape, AT, rand(10), Ref((10,)))
93+
@test compare(reshape, AT, rand(10), Ref((10,1)))
94+
@test compare(reshape, AT, rand(10), Ref((1,10)))
95+
96+
# the reshaped value must still be an AT
@test reshape(AT(rand(10)), (10,1)) isa AT
97+
# 10 elements cannot fill a (10,2) array, so this must throw
@test_throws Exception reshape(AT(rand(10)), (10,2))
98+
end
99+
91100
@testset "reinterpret" begin
92101
a = rand(ComplexF32, 22)
93102
A = AT(a)

test/testsuite/linalg.jl

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -62,36 +62,27 @@ function test_linalg(AT)
6262
end
6363

6464
@testset "matrix multiplication" begin
65-
a = rand(Int8, 3, 3)
66-
b = rand(Int8, 3, 3)
67-
d_a = AT{Int8}(a)
68-
d_b = AT{Int8}(b)
69-
d_c = d_a*d_b
70-
@test collect(d_c) == a*b
71-
a = rand(Complex{Int8}, 3, 3)
72-
b = rand(Complex{Int8}, 3, 3)
73-
d_a = AT{Complex{Int8}}(a)
74-
d_b = AT{Complex{Int8}}(b)
75-
d_c = d_a'*d_b
76-
@test collect(d_c) == a'*b
77-
d_c = d_a*d_b'
78-
@test collect(d_c) == a*b'
79-
d_c = d_a'*d_b'
80-
@test collect(d_c) == a'*b'
81-
d_c = transpose(d_a)*d_b'
82-
@test collect(d_c) == transpose(a)*b'
83-
d_c = d_a'*transpose(d_b)
84-
@test collect(d_c) == a'*transpose(b)
85-
d_c = transpose(d_a)*d_b
86-
@test collect(d_c) == transpose(a)*b
87-
d_c = d_a*transpose(d_b)
88-
@test collect(d_c) == a*transpose(b)
89-
d_c = transpose(d_a)*transpose(d_b)
90-
@test collect(d_c) == transpose(a)*transpose(b)
91-
d_c = rmul!(copy(d_a), Complex{Int8}(2, 2))
92-
@test collect(d_c) == a*Complex{Int8}(2, 2)
93-
d_c = lmul!(Complex{Int8}(2, 2), copy(d_a))
94-
@test collect(d_c) == Complex{Int8}(2, 2)*a
65+
# Run `*`, `rmul!` and `lmul!` for matrix*matrix, vector*matrix and matrix*vector
# shape pairs over every element type returned by supported_eltypes().
# Shapes are tuples throughout; a vector shape is the 1-tuple `(3,)`. The bare
# `(3)` used before is just the Int 3 and only worked because length(3) == 1.
for (a,b) in [((3,4),(4,3)), ((3,), (1,3)), ((1,3), (3,))], T in supported_eltypes()
66+
@test compare(*, AT, rand(T, a), rand(T, b))
67+
68+
# transpose/adjoint variants are only built for matrix-shaped operands
if length(a) > 1
69+
@test compare(*, AT, transpose(rand(T, reverse(a))), rand(T, b))
70+
@test compare(*, AT, adjoint(rand(T, reverse(a))), rand(T, b))
71+
end
72+
73+
if length(b) > 1
74+
@test compare(*, AT, rand(T, a), transpose(rand(T, reverse(b))))
75+
@test compare(*, AT, rand(T, a), adjoint(rand(T, reverse(b))))
76+
end
77+
78+
if length(a) > 1 && length(b) > 1
79+
@test compare(*, AT, transpose(rand(T, reverse(a))), transpose(rand(T, reverse(b))))
80+
@test compare(*, AT, adjoint(rand(T, reverse(a))), adjoint(rand(T, reverse(b))))
81+
end
82+
83+
@test compare(rmul!, AT, rand(T, a), Ref(rand(T)))
84+
@test compare(lmul!, AT, Ref(rand(T)), rand(T, b))
85+
end
9586
end
9687
end
9788
end

0 commit comments

Comments
 (0)