@@ -599,11 +599,17 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo
599599 throw (DimensionMismatch (lazy " output matrix has size: $(size(C)), but should have size $((mA, mA))" ))
600600 end
601601
602- _rmul_or_fill! (C, β)
602+ if (! iszero (β) || isempty (A)) # return C*beta
603+ _rmul_or_fill! (C, β)
604+ else # iszero(β) && A is non-empty
605+ aA_11 = abs2 (A[1 ,1 ])
606+ fill! (UpperTriangular (C), zero (aA_11 + aA_11))
607+ end
608+ iszero (α) && return C
603609 @inbounds if ! conjugate
604610 if aat
605611 for k ∈ 1 : n, j ∈ 1 : m
606- αA_jk = A[j, k] * α
612+ αA_jk = @stable_muladdmul MulAddMul (α, false )( A[j, k])
607613 for i ∈ 1 : j
608614 C[i, j] += A[i, k] * αA_jk
609615 end
@@ -614,17 +620,17 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo
614620 for k ∈ 2 : m
615621 temp += A[k, i] * A[k, j]
616622 end
617- C[i, j] += temp * α
623+ C[i, j] += @stable_muladdmul MulAddMul (α, false )(temp)
618624 end
619625 end
620626 else
621627 if aat
622628 for k ∈ 1 : n, j ∈ 1 : m
623- αA_jk_bar = conj (A[j, k]) * α
629+ αA_jk_bar = @stable_muladdmul MulAddMul (α, false )( conj (A[j, k]))
624630 for i ∈ 1 : j- 1
625631 C[i, j] += A[i, k] * αA_jk_bar
626632 end
627- C[j, j] += abs2 (A[j, k]) * α
633+ C[j, j] += @stable_muladdmul MulAddMul (α, false )( abs2 (A[j, k]))
628634 end
629635 else
630636 for j ∈ 1 : n
@@ -633,13 +639,13 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo
633639 for k ∈ 2 : m
634640 temp += conj (A[k, i]) * A[k, j]
635641 end
636- C[i, j] += temp * α
642+ C[i, j] += @stable_muladdmul MulAddMul (α, false )(temp)
637643 end
638644 temp = abs2 (A[1 , j])
639645 for k ∈ 2 : m
640646 temp += abs2 (A[k, j])
641647 end
642- C[j, j] += temp * α
648+ C[j, j] += @stable_muladdmul MulAddMul (α, false )(temp)
643649 end
644650 end
645651 end
@@ -1132,8 +1138,21 @@ __generic_matmatmul!(C, A, B, alpha, beta, ::Val{true}) = _generic_matmatmul_non
11321138__generic_matmatmul! (C, A, B, alpha, beta, :: Val{false} ) = _generic_matmatmul_generic! (C, A, B, alpha, beta)
11331139
11341140function _generic_matmatmul_nonadjtrans! (C, A, B, alpha, beta)
1135- _rmul_or_fill! (C, beta)
1136- (iszero (alpha) || isempty (A) || isempty (B)) && return C
1141+ # _rmul_or_fill!(C, beta) spelled out more carefully to allow for zero-less eltypes
1142+ if (! iszero (beta) || isempty (A) || isempty (B)) # return C*beta
1143+ _rmul_or_fill! (C, beta)
1144+ else # iszero(beta) && A and B are non-empty
1145+ a1 = firstindex (A, 2 )
1146+ b1 = firstindex (B, 1 )
1147+ for j in axes (C, 2 )
1148+ B_1j = B[b1, j]
1149+ for i in axes (C, 1 )
1150+ C_ij = A[i, a1] * B_1j
1151+ C[i,j] = zero (C_ij + C_ij)
1152+ end
1153+ end
1154+ end
1155+ iszero (alpha) && return C
11371156 @inbounds for n in axes (B, 2 ), k in axes (B, 1 )
11381157 # Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
11391158 Balpha = @stable_muladdmul MulAddMul (alpha, false )(B[k,n])
@@ -1145,20 +1164,40 @@ function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta)
11451164 C
11461165end
# NOTE(review): this definition is shown as a unified-diff rendering: the digits fused at the
# start of each line look like old/new line numbers, `-` lines are removed and `+` lines are
# added by the patch. The comments below describe the code as it reads after the patch.
#
# _generic_matmatmul_adjtrans!(C, A, B, alpha, beta)
#
# Update `C` in place and return it. The kernel works on `parent(A)` and `parent(B)` and
# re-applies `t = _wrapperop(A)` to scalars/elements (presumably `adjoint` or `transpose`,
# given the function name -- `_wrapperop` is defined elsewhere; confirm).
# Steps visible here:
#   1. `C` is either passed to `_rmul_or_fill!(C, beta)` or overwritten with zeros.
#   2. If `alpha` is zero, `C` is returned as is.
#   3. Otherwise, for each `i in axes(A, 1)`, one row of `C` is incremented by `t` applied
#      elementwise to `ta .* (pB * pA[:, i])`, where `ta = t(alpha)`.
11471166function _generic_matmatmul_adjtrans! (C, A, B, alpha, beta)
1148- _rmul_or_fill! (C, beta)
1149- (iszero (alpha) || isempty (A) || isempty (B)) && return C
11501167 t = _wrapperop (A)
11511168 pB = parent (B)
11521169 pA = parent (A)
# beta is nonzero, or one operand is empty: delegate the handling of C to _rmul_or_fill!.
1170+ if (! iszero (beta) || isempty (A) || isempty (B)) # return C*beta
1171+ _rmul_or_fill! (C, beta)
# Otherwise every C[i,j] is overwritten with a zero obtained from an actual product of
# entries, rather than from the element type. Presumably this supports element types that
# have no `zero(T)` method (cf. the "zero-less eltypes" remark in the sibling kernel) -- confirm.
1172+ else # iszero(beta) && A and B are non-empty
# First indices of the parents along the dimension that is not looped over below.
1173+ a1 = firstindex (pA, 1 )
1174+ b1 = firstindex (pB, 2 )
1175+ for j in axes (C, 2 )
1176+ tB_1j = t (pB[j, b1])
1177+ for i in axes (C, 1 )
1178+ C_ij = t (pA[a1, i]) * tB_1j
# zero(x + x): the zero is taken from the sum of two products, not from the product itself.
1179+ C[i,j] = zero (C_ij + C_ij)
1180+ end
1181+ end
1182+ end
# Nothing to accumulate when alpha is zero.
1183+ iszero (alpha) && return C
# Work vector indexed like the columns of C; it receives pB * pA[:, i] on each iteration.
11531184 tmp = similar (C, axes (C, 2 ))
# Row cursor into C, advanced once per iteration over axes(A, 1).
11541185 ci = firstindex (C, 1 )
# The wrapper op is applied to alpha up front because `t` is applied again to the scaled
# product below (presumably so that t(ta * x) yields alpha * t(x) -- confirm).
11551186 ta = t (alpha)
1156- for i in axes (A, 1 )
1157- mul! (tmp, pB, view (pA, :, i))
1158- @views C[ci,:] .+ = t .(ta .* tmp)
1159- ci += 1
# The two branches differ only in whether tmp is multiplied by ta; the isone(ta) branch
# skips that multiplication.
1187+ if isone (ta)
1188+ for i in axes (A, 1 )
1189+ mul! (tmp, pB, view (pA, :, i))
1190+ @views C[ci,:] .+ = t .(tmp)
1191+ ci += 1
1192+ end
1193+ else
1194+ for i in axes (A, 1 )
1195+ mul! (tmp, pB, view (pA, :, i))
1196+ @views C[ci,:] .+ = t .(ta .* tmp)
1197+ ci += 1
1198+ end
11601199 end
1161- C
1200+ return C
11621201end
11631202function _generic_matmatmul_generic! (C, A, B, alpha, beta)
11641203 if iszero (alpha) || isempty (A) || isempty (B)
0 commit comments