Skip to content

Commit d9fc5ea

Browse files
committed
Consistently check matrix sizes in matmul
1 parent 6e5ea12 commit d9fc5ea

File tree

4 files changed

+69
-82
lines changed

4 files changed

+69
-82
lines changed

src/bidiag.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ end
497497

498498
# B .= A * B
499499
function lmul!(A::Bidiagonal, B::AbstractVecOrMat)
500-
_muldiag_size_check(size(A), size(B))
500+
matmul_size_check(size(A), size(B))
501501
(; dv, ev) = A
502502
if A.uplo == 'U'
503503
for k in axes(B,2)
@@ -518,7 +518,7 @@ function lmul!(A::Bidiagonal, B::AbstractVecOrMat)
518518
end
519519
# B .= D * B
520520
function lmul!(D::Diagonal, B::Bidiagonal)
521-
_muldiag_size_check(size(D), size(B))
521+
matmul_size_check(size(D), size(B))
522522
(; dv, ev) = B
523523
isL = B.uplo == 'L'
524524
dv[1] = D.diag[1] * dv[1]
@@ -530,7 +530,7 @@ function lmul!(D::Diagonal, B::Bidiagonal)
530530
end
531531
# B .= B * A
532532
function rmul!(B::AbstractMatrix, A::Bidiagonal)
533-
_muldiag_size_check(size(A), size(B))
533+
matmul_size_check(size(A), size(B))
534534
(; dv, ev) = A
535535
if A.uplo == 'U'
536536
for k in reverse(axes(dv,1)[2:end])
@@ -555,7 +555,7 @@ function rmul!(B::AbstractMatrix, A::Bidiagonal)
555555
end
556556
# B .= B * D
557557
function rmul!(B::Bidiagonal, D::Diagonal)
558-
_muldiag_size_check(size(B), size(D))
558+
matmul_size_check(size(B), size(D))
559559
(; dv, ev) = B
560560
isU = B.uplo == 'U'
561561
dv[1] *= D.diag[1]

src/diagonal.jl

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -322,39 +322,18 @@ Base.literal_pow(::typeof(^), D::Diagonal, valp::Val) =
322322
Diagonal(Base.literal_pow.(^, D.diag, valp)) # for speed
323323
Base.literal_pow(::typeof(^), D::Diagonal, ::Val{-1}) = inv(D) # for disambiguation
324324

325-
function _muldiag_size_check(szA::NTuple{2,Integer}, szB::Tuple{Integer,Vararg{Integer}})
326-
nA = szA[2]
327-
mB = szB[1]
328-
@noinline throw_dimerr(szB::NTuple{2}, nA, mB) = throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match first dimension of B, $mB"))
329-
@noinline throw_dimerr(szB::NTuple{1}, nA, mB) = throw(DimensionMismatch(lazy"second dimension of D, $nA, does not match length of V, $mB"))
330-
nA == mB || throw_dimerr(szB, nA, mB)
331-
return nothing
332-
end
333-
# the output matrix should have the same size as the non-diagonal input matrix or vector
334-
@noinline throw_dimerr(szC, szA) = throw(DimensionMismatch(lazy"output matrix has size: $szC, but should have size $szA"))
335-
function _size_check_out(szC::NTuple{2}, szA::NTuple{2}, szB::NTuple{2})
336-
(szC[1] == szA[1] && szC[2] == szB[2]) || throw_dimerr(szC, (szA[1], szB[2]))
337-
end
338-
function _size_check_out(szC::NTuple{1}, szA::NTuple{2}, szB::NTuple{1})
339-
szC[1] == szA[1] || throw_dimerr(szC, (szA[1],))
340-
end
341-
function _muldiag_size_check(szC::Tuple{Vararg{Integer}}, szA::Tuple{Vararg{Integer}}, szB::Tuple{Vararg{Integer}})
342-
_muldiag_size_check(szA, szB)
343-
_size_check_out(szC, szA, szB)
344-
end
345-
346325
function (*)(Da::Diagonal, Db::Diagonal)
347-
_muldiag_size_check(size(Da), size(Db))
326+
matmul_size_check(size(Da), size(Db))
348327
return Diagonal(Da.diag .* Db.diag)
349328
end
350329

351330
function (*)(D::Diagonal, V::AbstractVector)
352-
_muldiag_size_check(size(D), size(V))
331+
matmul_size_check(size(D), size(V))
353332
return D.diag .* V
354333
end
355334

356335
function rmul!(A::AbstractMatrix, D::Diagonal)
357-
_muldiag_size_check(size(A), size(D))
336+
matmul_size_check(size(A), size(D))
358337
for I in CartesianIndices(A)
359338
row, col = Tuple(I)
360339
@inbounds A[row, col] *= D.diag[col]
@@ -363,7 +342,7 @@ function rmul!(A::AbstractMatrix, D::Diagonal)
363342
end
364343
# T .= T * D
365344
function rmul!(T::Tridiagonal, D::Diagonal)
366-
_muldiag_size_check(size(T), size(D))
345+
matmul_size_check(size(T), size(D))
367346
(; dl, d, du) = T
368347
d[1] *= D.diag[1]
369348
for i in axes(dl,1)
@@ -375,7 +354,7 @@ function rmul!(T::Tridiagonal, D::Diagonal)
375354
end
376355

377356
function lmul!(D::Diagonal, B::AbstractVecOrMat)
378-
_muldiag_size_check(size(D), size(B))
357+
matmul_size_check(size(D), size(B))
379358
for I in CartesianIndices(B)
380359
row = I[1]
381360
@inbounds B[I] = D.diag[row] * B[I]
@@ -386,7 +365,7 @@ end
386365
# in-place multiplication with a diagonal
387366
# T .= D * T
388367
function lmul!(D::Diagonal, T::Tridiagonal)
389-
_muldiag_size_check(size(D), size(T))
368+
matmul_size_check(size(D), size(T))
390369
(; dl, d, du) = T
391370
d[1] = D.diag[1] * d[1]
392371
for i in axes(dl,1)
@@ -507,7 +486,7 @@ end
507486
# specialize the non-trivial case
508487
function _mul_diag!(out, A, B, alpha, beta)
509488
require_one_based_indexing(out, A, B)
510-
_muldiag_size_check(size(out), size(A), size(B))
489+
matmul_size_check(size(out), size(A), size(B))
511490
if iszero(alpha)
512491
_rmul_or_fill!(out, beta)
513492
else
@@ -532,14 +511,14 @@ _mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number
532511
_mul_diag!(C, Da, Db, alpha, beta)
533512

534513
function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal)
535-
_muldiag_size_check(size(Da), size(A))
536-
_muldiag_size_check(size(A), size(Db))
514+
matmul_size_check(size(Da), size(A))
515+
matmul_size_check(size(A), size(Db))
537516
return broadcast(*, Da.diag, A, permutedims(Db.diag))
538517
end
539518

540519
function (*)(Da::Diagonal, Db::Diagonal, Dc::Diagonal)
541-
_muldiag_size_check(size(Da), size(Db))
542-
_muldiag_size_check(size(Db), size(Dc))
520+
matmul_size_check(size(Da), size(Db))
521+
matmul_size_check(size(Db), size(Dc))
543522
return Diagonal(Da.diag .* Db.diag .* Dc.diag)
544523
end
545524

src/matmul.jl

Lines changed: 43 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,40 @@ julia> lmul!(F.Q, B)
408408
"""
409409
lmul!(A, B)
410410

411+
_vec_or_mat_str(s::Tuple{Any}) = :vector
412+
_vec_or_mat_str(s::Tuple{Any,Any}) = :matrix
413+
@noinline function matmul_size_check(sizeA::Tuple{Integer,Vararg{Integer}}, sizeB::Tuple{Integer,Vararg{Integer}})
414+
strA = _vec_or_mat_str(sizeA)
415+
strB = _vec_or_mat_str(sizeB)
416+
szA2 = get(sizeA, 2, 1)
417+
if szA2 != sizeB[1]
418+
throw(DimensionMismatch(
419+
lazy"incompatible dimensions for matrix multiplication: tried to multiply a $strA of size $sizeA with a $strB of size $sizeB.",
420+
)
421+
)
422+
end
423+
return nothing
424+
end
425+
@noinline function matmul_size_check(sizeC::Tuple{Integer,Vararg{Integer}}, sizeA::Tuple{Integer,Vararg{Integer}}, sizeB::Tuple{Integer,Vararg{Integer}})
426+
strA = _vec_or_mat_str(sizeA)
427+
strB = _vec_or_mat_str(sizeB)
428+
szB2 = get(sizeB, 2, 1)
429+
szC2 = get(sizeC, 2, 1)
430+
matmul_size_check(sizeA, sizeB)
431+
if sizeC[1] != sizeA[1] || szC2 != szB2
432+
destsize = length(sizeB) == length(sizeC) == 1 ? (sizeA[1],) : (sizeA[1], szB2)
433+
throw(DimensionMismatch(
434+
LazyString(
435+
"incompatible destination size: ",
436+
lazy"the destination of size $sizeC is incomatible with the multiplication of a $strA of size $(sizeA) and a $strB of size $(sizeB). ",
437+
lazy"The destination must be of size $destsize."
438+
)
439+
)
440+
)
441+
end
442+
return nothing
443+
end
444+
411445
# We may inline the matmul2x2! and matmul3x3! calls for `α == true`
412446
# to simplify the @stable_muladdmul branches
413447
function matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β)
@@ -441,9 +475,7 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix
441475
mA, nA = lapack_size(tA, A)
442476
mB, nB = lapack_size(tB, B)
443477
if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
444-
if size(C) != (mA, nB)
445-
throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)"))
446-
end
478+
matmul_size_check(size(C), (mA, nA), (mB, nB))
447479
return _rmul_or_fill!(C, β)
448480
end
449481
matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C
@@ -475,9 +507,7 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix
475507
mA, nA = lapack_size(tA, A)
476508
mB, nB = lapack_size(tB, B)
477509
if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
478-
if size(C) != (mA, nB)
479-
throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)"))
480-
end
510+
matmul_size_check(size(C), (mA, nA), (mB, nB))
481511
return _rmul_or_fill!(C, β)
482512
end
483513
matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C
@@ -571,10 +601,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar
571601
A::StridedVecOrMat{T}, x::StridedVector{T},
572602
α::Number=true, β::Number=false) where {T<:BlasFloat}
573603
mA, nA = lapack_size(tA, A)
574-
nA != length(x) &&
575-
throw(DimensionMismatch(lazy"second dimension of matrix, $nA, does not match length of input vector, $(length(x))"))
576-
mA != length(y) &&
577-
throw(DimensionMismatch(lazy"first dimension of matrix, $mA, does not match length of output vector, $(length(y))"))
604+
matmul_size_check(size(y), (mA, nA), size(x))
578605
mA == 0 && return y
579606
nA == 0 && return _rmul_or_fill!(y, β)
580607
alpha, beta = promote(α, β, zero(T))
@@ -602,10 +629,7 @@ end
602629
Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
603630
α::Number = true, β::Number = false) where {T<:BlasReal}
604631
mA, nA = lapack_size(tA, A)
605-
nA != length(x) &&
606-
throw(DimensionMismatch(lazy"second dimension of matrix, $nA, does not match length of input vector, $(length(x))"))
607-
mA != length(y) &&
608-
throw(DimensionMismatch(lazy"first dimension of matrix, $mA, does not match length of output vector, $(length(y))"))
632+
matmul_size_check(size(y), (mA, nA), size(x))
609633
mA == 0 && return y
610634
nA == 0 && return _rmul_or_fill!(y, β)
611635
alpha, beta = promote(α, β, zero(T))
@@ -626,10 +650,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
626650
A::StridedVecOrMat{T}, x::StridedVector{Complex{T}},
627651
α::Number = true, β::Number = false) where {T<:BlasReal}
628652
mA, nA = lapack_size(tA, A)
629-
nA != length(x) &&
630-
throw(DimensionMismatch(lazy"second dimension of matrix, $nA, does not match length of input vector, $(length(x))"))
631-
mA != length(y) &&
632-
throw(DimensionMismatch(lazy"first dimension of matrix, $mA, does not match length of output vector, $(length(y))"))
653+
matmul_size_check(size(y), (mA, nA), size(x))
633654
mA == 0 && return y
634655
nA == 0 && return _rmul_or_fill!(y, β)
635656
alpha, beta = promote(α, β, zero(T))
@@ -748,9 +769,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab
748769
mA, nA = lapack_size(tA, A)
749770
mB, nB = lapack_size(tB, B)
750771

751-
if nA != mB
752-
throw(DimensionMismatch(lazy"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))
753-
end
772+
matmul_size_check(size(C), (mA, nA), (mB, nB))
754773

755774
if C === A || B === C
756775
throw(ArgumentError("output matrix must not be aliased with input matrix"))
@@ -778,9 +797,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}
778797
mA, nA = lapack_size(tA, A)
779798
mB, nB = lapack_size(tB, B)
780799

781-
if nA != mB
782-
throw(DimensionMismatch(lazy"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))
783-
end
800+
matmul_size_check(size(C), (mA, nA), (mB, nB))
784801

785802
if C === A || B === C
786803
throw(ArgumentError("output matrix must not be aliased with input matrix"))
@@ -940,14 +957,8 @@ function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::Abst
940957
alpha::Number, beta::Number)
941958
require_one_based_indexing(C, A, B)
942959
@assert tA in ('N', 'T', 'C')
943-
mB = length(B)
944960
mA, nA = lapack_size(tA, A)
945-
if mB != nA
946-
throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), vector B has length $mB"))
947-
end
948-
if mA != length(C)
949-
throw(DimensionMismatch(lazy"result C has length $(length(C)), needs length $mA"))
950-
end
961+
matmul_size_check(size(C), (mA, nA), size(B))
951962

952963
if tA == 'T' # fastest case
953964
__generic_matvecmul!(transpose, C, A, B, alpha, beta)
@@ -979,21 +990,7 @@ _generic_matmatmul!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::AbstractVecOrMa
979990

980991
@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat, B::AbstractVecOrMat,
981992
alpha::Number, beta::Number) where {R}
982-
AxM = axes(A, 1)
983-
AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector`
984-
BxK = axes(B, 1)
985-
BxN = axes(B, 2)
986-
CxM = axes(C, 1)
987-
CxN = axes(C, 2)
988-
if AxM != CxM
989-
throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix C has axes ($CxM,$CxN)"))
990-
end
991-
if AxK != BxK
992-
throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix B has axes ($BxK,$CxN)"))
993-
end
994-
if BxN != CxN
995-
throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)"))
996-
end
993+
matmul_size_check(size(C), size(A), size(B))
997994
__generic_matmatmul!(C, A, B, alpha, beta, Val(isbitstype(R) && sizeof(R) 16))
998995
return C
999996
end

test/matmul.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,4 +1148,15 @@ end
11481148
@test A * A M * M
11491149
end
11501150

1151+
@testset "issue #1147: error messages in matmul" begin
1152+
@test_throws "incompatible dimensions for matrix multiplication" zeros(0,0) * zeros(1,5)
1153+
@test_throws "incompatible dimensions for matrix multiplication" zeros(0,0) * zeros(1)
1154+
@test_throws "incompatible dimensions for matrix multiplication" zeros(0) * zeros(2,5)
1155+
@test_throws "incompatible dimensions for matrix multiplication" mul!(zeros(0,0), zeros(5), zeros(5))
1156+
@test_throws "incompatible dimensions for matrix multiplication" mul!(zeros(0,0), zeros(1,5), zeros(0,0))
1157+
@test_throws "incompatible destination size" mul!(zeros(0,0), zeros(1,5), zeros(5,2))
1158+
@test_throws "incompatible destination size" mul!(zeros(0,0), zeros(1,5), zeros(5))
1159+
@test_throws "incompatible destination size" mul!(zeros(0), zeros(1,5), zeros(5))
1160+
end
1161+
11511162
end # module TestMatmul

0 commit comments

Comments
 (0)