@@ -408,6 +408,40 @@ julia> lmul!(F.Q, B)
408408"""
409409lmul! (A, B)
410410
411+ _vec_or_mat_str (s:: Tuple{Any} ) = :vector
412+ _vec_or_mat_str (s:: Tuple{Any,Any} ) = :matrix
413+ @noinline function matmul_size_check (sizeA:: Tuple{Integer,Vararg{Integer}} , sizeB:: Tuple{Integer,Vararg{Integer}} )
414+ strA = _vec_or_mat_str (sizeA)
415+ strB = _vec_or_mat_str (sizeB)
416+ szA2 = get (sizeA, 2 , 1 )
417+ if szA2 != sizeB[1 ]
418+ throw (DimensionMismatch (
419+ lazy " incompatible dimensions for matrix multiplication: tried to multiply a $strA of size $sizeA with a $strB of size $sizeB." ,
420+ )
421+ )
422+ end
423+ return nothing
424+ end
425+ @noinline function matmul_size_check (sizeC:: Tuple{Integer,Vararg{Integer}} , sizeA:: Tuple{Integer,Vararg{Integer}} , sizeB:: Tuple{Integer,Vararg{Integer}} )
426+ strA = _vec_or_mat_str (sizeA)
427+ strB = _vec_or_mat_str (sizeB)
428+ szB2 = get (sizeB, 2 , 1 )
429+ szC2 = get (sizeC, 2 , 1 )
430+ matmul_size_check (sizeA, sizeB)
431+ if sizeC[1 ] != sizeA[1 ] || szC2 != szB2
432+ destsize = length (sizeB) == length (sizeC) == 1 ? (sizeA[1 ],) : (sizeA[1 ], szB2)
433+ throw (DimensionMismatch (
434+ LazyString (
435+ " incompatible destination size: " ,
436+ lazy " the destination of size $sizeC is incomatible with the multiplication of a $strA of size $(sizeA) and a $strB of size $(sizeB). " ,
437+ lazy " The destination must be of size $destsize."
438+ )
439+ )
440+ )
441+ end
442+ return nothing
443+ end
444+
411445# We may inline the matmul2x2! and matmul3x3! calls for `α == true`
412446# to simplify the @stable_muladdmul branches
413447function matmul2x2or3x3_nonzeroalpha! (C, tA, tB, A, B, α, β)
@@ -441,9 +475,7 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix
441475 mA, nA = lapack_size (tA, A)
442476 mB, nB = lapack_size (tB, B)
443477 if any (iszero, size (A)) || any (iszero, size (B)) || iszero (α)
444- if size (C) != (mA, nB)
445- throw (DimensionMismatch (lazy " C has dimensions $(size(C)), should have ($mA,$nB)" ))
446- end
478+ matmul_size_check (size (C), (mA, nA), (mB, nB))
447479 return _rmul_or_fill! (C, β)
448480 end
449481 matmul2x2or3x3_nonzeroalpha! (C, tA, tB, A, B, α, β) && return C
@@ -475,9 +507,7 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix
475507 mA, nA = lapack_size (tA, A)
476508 mB, nB = lapack_size (tB, B)
477509 if any (iszero, size (A)) || any (iszero, size (B)) || iszero (α)
478- if size (C) != (mA, nB)
479- throw (DimensionMismatch (lazy " C has dimensions $(size(C)), should have ($mA,$nB)" ))
480- end
510+ matmul_size_check (size (C), (mA, nA), (mB, nB))
481511 return _rmul_or_fill! (C, β)
482512 end
483513 matmul2x2or3x3_nonzeroalpha! (C, tA, tB, A, B, α, β) && return C
@@ -571,10 +601,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar
571601 A:: StridedVecOrMat{T} , x:: StridedVector{T} ,
572602 α:: Number = true , β:: Number = false ) where {T<: BlasFloat }
573603 mA, nA = lapack_size (tA, A)
574- nA != length (x) &&
575- throw (DimensionMismatch (lazy " second dimension of matrix, $nA, does not match length of input vector, $(length(x))" ))
576- mA != length (y) &&
577- throw (DimensionMismatch (lazy " first dimension of matrix, $mA, does not match length of output vector, $(length(y))" ))
604+ matmul_size_check (size (y), (mA, nA), size (x))
578605 mA == 0 && return y
579606 nA == 0 && return _rmul_or_fill! (y, β)
580607 alpha, beta = promote (α, β, zero (T))
602629Base. @constprop :aggressive function gemv! (y:: StridedVector{Complex{T}} , tA:: AbstractChar , A:: StridedVecOrMat{Complex{T}} , x:: StridedVector{T} ,
603630 α:: Number = true , β:: Number = false ) where {T<: BlasReal }
604631 mA, nA = lapack_size (tA, A)
605- nA != length (x) &&
606- throw (DimensionMismatch (lazy " second dimension of matrix, $nA, does not match length of input vector, $(length(x))" ))
607- mA != length (y) &&
608- throw (DimensionMismatch (lazy " first dimension of matrix, $mA, does not match length of output vector, $(length(y))" ))
632+ matmul_size_check (size (y), (mA, nA), size (x))
609633 mA == 0 && return y
610634 nA == 0 && return _rmul_or_fill! (y, β)
611635 alpha, beta = promote (α, β, zero (T))
@@ -626,10 +650,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
626650 A:: StridedVecOrMat{T} , x:: StridedVector{Complex{T}} ,
627651 α:: Number = true , β:: Number = false ) where {T<: BlasReal }
628652 mA, nA = lapack_size (tA, A)
629- nA != length (x) &&
630- throw (DimensionMismatch (lazy " second dimension of matrix, $nA, does not match length of input vector, $(length(x))" ))
631- mA != length (y) &&
632- throw (DimensionMismatch (lazy " first dimension of matrix, $mA, does not match length of output vector, $(length(y))" ))
653+ matmul_size_check (size (y), (mA, nA), size (x))
633654 mA == 0 && return y
634655 nA == 0 && return _rmul_or_fill! (y, β)
635656 alpha, beta = promote (α, β, zero (T))
@@ -748,9 +769,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab
748769 mA, nA = lapack_size (tA, A)
749770 mB, nB = lapack_size (tB, B)
750771
751- if nA != mB
752- throw (DimensionMismatch (lazy " A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)" ))
753- end
772+ matmul_size_check (size (C), (mA, nA), (mB, nB))
754773
755774 if C === A || B === C
756775 throw (ArgumentError (" output matrix must not be aliased with input matrix" ))
@@ -778,9 +797,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}
778797 mA, nA = lapack_size (tA, A)
779798 mB, nB = lapack_size (tB, B)
780799
781- if nA != mB
782- throw (DimensionMismatch (lazy " A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)" ))
783- end
800+ matmul_size_check (size (C), (mA, nA), (mB, nB))
784801
785802 if C === A || B === C
786803 throw (ArgumentError (" output matrix must not be aliased with input matrix" ))
@@ -940,14 +957,8 @@ function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::Abst
940957 alpha:: Number , beta:: Number )
941958 require_one_based_indexing (C, A, B)
942959 @assert tA in (' N' , ' T' , ' C' )
943- mB = length (B)
944960 mA, nA = lapack_size (tA, A)
945- if mB != nA
946- throw (DimensionMismatch (lazy " matrix A has dimensions ($mA,$nA), vector B has length $mB" ))
947- end
948- if mA != length (C)
949- throw (DimensionMismatch (lazy " result C has length $(length(C)), needs length $mA" ))
950- end
961+ matmul_size_check (size (C), (mA, nA), size (B))
951962
952963 if tA == ' T' # fastest case
953964 __generic_matvecmul! (transpose, C, A, B, alpha, beta)
@@ -979,21 +990,7 @@ _generic_matmatmul!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::AbstractVecOrMa
979990
980991@noinline function _generic_matmatmul! (C:: AbstractVecOrMat{R} , A:: AbstractVecOrMat , B:: AbstractVecOrMat ,
981992 alpha:: Number , beta:: Number ) where {R}
982- AxM = axes (A, 1 )
983- AxK = axes (A, 2 ) # we use two `axes` calls in case of `AbstractVector`
984- BxK = axes (B, 1 )
985- BxN = axes (B, 2 )
986- CxM = axes (C, 1 )
987- CxN = axes (C, 2 )
988- if AxM != CxM
989- throw (DimensionMismatch (lazy " matrix A has axes ($AxM,$AxK), matrix C has axes ($CxM,$CxN)" ))
990- end
991- if AxK != BxK
992- throw (DimensionMismatch (lazy " matrix A has axes ($AxM,$AxK), matrix B has axes ($BxK,$CxN)" ))
993- end
994- if BxN != CxN
995- throw (DimensionMismatch (lazy " matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)" ))
996- end
993+ matmul_size_check (size (C), size (A), size (B))
997994 __generic_matmatmul! (C, A, B, alpha, beta, Val (isbitstype (R) && sizeof (R) ≤ 16 ))
998995 return C
999996end
0 commit comments