Skip to content

Commit 9b7cfdd

Browse files
committed
Use matmul_size_check in 2x2 and 3x3 matmul
1 parent 7e1ad94 commit 9b7cfdd

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/matmul.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -521,8 +521,8 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix
521521
α::Number, β::Number, val::BlasFlag.SymmHemmGeneric) where {T<:BlasFloat}
522522
mA, nA = lapack_size(tA, A)
523523
mB, nB = lapack_size(tB, B)
524+
matmul_size_check(size(C), (mA, nA), (mB, nB))
524525
if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
525-
matmul_size_check(size(C), (mA, nA), (mB, nB))
526526
return _rmul_or_fill!(C, β)
527527
end
528528
matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C
@@ -701,7 +701,7 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst
701701
tAt = 'T'
702702
end
703703
if nC != mA
704-
throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)"))
704+
throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))"))
705705
end
706706

707707
# BLAS.syrk! only updates symmetric C
@@ -735,7 +735,7 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St
735735
tAt = 'C'
736736
end
737737
if nC != mA
738-
throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)"))
738+
throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))"))
739739
end
740740

741741
# Result array does not need to be initialized as long as beta==0
@@ -1067,11 +1067,12 @@ end
10671067

10681068
function __matmul_checks(C, A, B, sz)
10691069
require_one_based_indexing(C, A, B)
1070+
matmul_size_check(size(C), size(A), size(B))
10701071
if C === A || B === C
10711072
throw(ArgumentError("output matrix must not be aliased with input matrix"))
10721073
end
10731074
if !(size(A) == size(B) == size(C) == sz)
1074-
throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
1075+
throw(DimensionMismatch(lazy"expected size: $sz, but got $(size(A))"))
10751076
end
10761077
return nothing
10771078
end

0 commit comments

Comments
 (0)