Skip to content

Commit b49dd0b

Browse files
committed
Use matmul_size_check for tridiag and triangular
1 parent 50ef837 commit b49dd0b

File tree

2 files changed

+16
-32
lines changed

2 files changed

+16
-32
lines changed

src/bidiag.jl

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -566,22 +566,6 @@ function rmul!(B::Bidiagonal, D::Diagonal)
566566
return B
567567
end
568568

569-
@noinline function check_A_mul_B!_sizes((mC, nC)::NTuple{2,Integer}, (mA, nA)::NTuple{2,Integer}, (mB, nB)::NTuple{2,Integer})
570-
# check for matching sizes in one column of B and C
571-
check_A_mul_B!_sizes((mC,), (mA, nA), (mB,))
572-
# ensure that the number of columns in B and C match
573-
if nB != nC
574-
throw(DimensionMismatch(lazy"second dimension of output C, $nC, and second dimension of B, $nB, must match"))
575-
end
576-
end
577-
@noinline function check_A_mul_B!_sizes((mC,)::Tuple{Integer}, (mA, nA)::NTuple{2,Integer}, (mB,)::Tuple{Integer})
578-
if mA != mC
579-
throw(DimensionMismatch(lazy"first dimension of A, $mA, and first dimension of output C, $mC, must match"))
580-
elseif nA != mB
581-
throw(DimensionMismatch(lazy"second dimension of A, $nA, and first dimension of B, $mB, must match"))
582-
end
583-
end
584-
585569
# function to get the internally stored vectors for Bidiagonal and [Sym]Tridiagonal
586570
# to avoid allocations in _mul! below (#24324, #24578)
587571
_diag(A::Tridiagonal, k) = k == -1 ? A.dl : k == 0 ? A.d : A.du
@@ -603,7 +587,7 @@ _mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) =
603587
_bibimul!(C, A, B, _add)
604588
function _bibimul!(C, A, B, _add)
605589
require_one_based_indexing(C)
606-
check_A_mul_B!_sizes(size(C), size(A), size(B))
590+
matmul_size_check(size(C), size(A), size(B))
607591
n = size(A,1)
608592
iszero(n) && return C
609593
# We use `_rmul_or_fill!` instead of `_modify!` here since using
@@ -851,7 +835,7 @@ _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number)
851835
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
852836
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
853837
require_one_based_indexing(C)
854-
check_A_mul_B!_sizes(size(C), size(A), size(B))
838+
matmul_size_check(size(C), size(A), size(B))
855839
n = size(A,1)
856840
iszero(n) && return C
857841
_rmul_or_fill!(C, _add.beta) # see the same use above
@@ -894,7 +878,7 @@ end
894878

895879
function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
896880
require_one_based_indexing(C)
897-
check_A_mul_B!_sizes(size(C), size(A), size(B))
881+
matmul_size_check(size(C), size(A), size(B))
898882
n = size(A,1)
899883
iszero(n) && return C
900884
_rmul_or_fill!(C, _add.beta) # see the same use above
@@ -924,7 +908,7 @@ function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
924908
end
925909

926910
function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
927-
check_A_mul_B!_sizes(size(C), size(A), size(B))
911+
matmul_size_check(size(C), size(A), size(B))
928912
n = size(A,1)
929913
iszero(n) && return C
930914
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
@@ -957,7 +941,7 @@ end
957941

958942
function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulAddMul)
959943
require_one_based_indexing(C, B)
960-
check_A_mul_B!_sizes(size(C), size(A), size(B))
944+
matmul_size_check(size(C), size(A), size(B))
961945
nA = size(A,1)
962946
nB = size(B,2)
963947
(iszero(nA) || iszero(nB)) && return C
@@ -1027,7 +1011,7 @@ end
10271011

10281012
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
10291013
require_one_based_indexing(C, A)
1030-
check_A_mul_B!_sizes(size(C), size(A), size(B))
1014+
matmul_size_check(size(C), size(A), size(B))
10311015
n = size(A,1)
10321016
m = size(B,2)
10331017
(iszero(_add.alpha) || iszero(m)) && return _rmul_or_fill!(C, _add.beta)
@@ -1063,7 +1047,7 @@ end
10631047

10641048
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul)
10651049
require_one_based_indexing(C, A)
1066-
check_A_mul_B!_sizes(size(C), size(A), size(B))
1050+
matmul_size_check(size(C), size(A), size(B))
10671051
m, n = size(A)
10681052
(iszero(m) || iszero(n)) && return C
10691053
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
@@ -1093,7 +1077,7 @@ _mul!(C::AbstractMatrix, A::Diagonal, B::TriSym, _add::MulAddMul) =
10931077
_dibimul!(C, A, B, _add)
10941078
function _dibimul!(C, A, B, _add)
10951079
require_one_based_indexing(C)
1096-
check_A_mul_B!_sizes(size(C), size(A), size(B))
1080+
matmul_size_check(size(C), size(A), size(B))
10971081
n = size(A,1)
10981082
iszero(n) && return C
10991083
# ensure that we fill off-band elements in the destination
@@ -1137,7 +1121,7 @@ function _dibimul!(C, A, B, _add)
11371121
end
11381122
function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add)
11391123
require_one_based_indexing(C)
1140-
check_A_mul_B!_sizes(size(C), size(A), size(B))
1124+
matmul_size_check(size(C), size(A), size(B))
11411125
n = size(A,1)
11421126
iszero(n) && return C
11431127
# ensure that we fill off-band elements in the destination
@@ -1168,7 +1152,7 @@ function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add)
11681152
C
11691153
end
11701154
function _dibimul!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
1171-
check_A_mul_B!_sizes(size(C), size(A), size(B))
1155+
matmul_size_check(size(C), size(A), size(B))
11721156
n = size(A,1)
11731157
n == 0 && return C
11741158
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)

src/triangular.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,7 +1094,7 @@ end
10941094

10951095
for TC in (:AbstractVector, :AbstractMatrix)
10961096
@eval @inline function _mul!(C::$TC, A::AbstractTriangular, B::AbstractVector, alpha::Number, beta::Number)
1097-
check_A_mul_B!_sizes(size(C), size(A), size(B))
1097+
matmul_size_check(size(C), size(A), size(B))
10981098
if isone(alpha) && iszero(beta)
10991099
return _trimul!(C, A, B)
11001100
else
@@ -1107,7 +1107,7 @@ for (TA, TB) in ((:AbstractTriangular, :AbstractMatrix),
11071107
(:AbstractTriangular, :AbstractTriangular)
11081108
)
11091109
@eval @inline function _mul!(C::AbstractMatrix, A::$TA, B::$TB, alpha::Number, beta::Number)
1110-
check_A_mul_B!_sizes(size(C), size(A), size(B))
1110+
matmul_size_check(size(C), size(A), size(B))
11111111
if isone(alpha) && iszero(beta)
11121112
return _trimul!(C, A, B)
11131113
else
@@ -1341,7 +1341,7 @@ end
13411341
## Generic triangular multiplication
13421342
function generic_trimatmul!(C::AbstractVecOrMat, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractVecOrMat)
13431343
require_one_based_indexing(C, A, B)
1344-
check_A_mul_B!_sizes(size(C), size(A), size(B))
1344+
matmul_size_check(size(C), size(A), size(B))
13451345
oA = oneunit(eltype(A))
13461346
unit = isunitc == 'U'
13471347
@inbounds if uploc == 'U'
@@ -1394,7 +1394,7 @@ end
13941394
# conjugate cases
13951395
function generic_trimatmul!(C::AbstractVecOrMat, uploc, isunitc, ::Function, xA::AdjOrTrans, B::AbstractVecOrMat)
13961396
require_one_based_indexing(C, xA, B)
1397-
check_A_mul_B!_sizes(size(C), size(xA), size(B))
1397+
matmul_size_check(size(C), size(xA), size(B))
13981398
A = parent(xA)
13991399
oA = oneunit(eltype(A))
14001400
unit = isunitc == 'U'
@@ -1424,7 +1424,7 @@ end
14241424

14251425
function generic_mattrimul!(C::AbstractMatrix, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix)
14261426
require_one_based_indexing(C, A, B)
1427-
check_A_mul_B!_sizes(size(C), size(A), size(B))
1427+
matmul_size_check(size(C), size(A), size(B))
14281428
oB = oneunit(eltype(B))
14291429
unit = isunitc == 'U'
14301430
@inbounds if uploc == 'U'
@@ -1477,7 +1477,7 @@ end
14771477
# conjugate cases
14781478
function generic_mattrimul!(C::AbstractMatrix, uploc, isunitc, ::Function, A::AbstractMatrix, xB::AdjOrTrans)
14791479
require_one_based_indexing(C, A, xB)
1480-
check_A_mul_B!_sizes(size(C), size(A), size(xB))
1480+
matmul_size_check(size(C), size(A), size(xB))
14811481
B = parent(xB)
14821482
oB = oneunit(eltype(B))
14831483
unit = isunitc == 'U'

0 commit comments

Comments
 (0)