@@ -599,11 +599,21 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo
599599 throw (DimensionMismatch (lazy " output matrix has size: $(size(C)), but should have size $((mA, mA))" ))
600600 end
601601
602- _rmul_or_fill! (C, β)
602+ if (! iszero (β) || isempty (A)) # return C*beta
603+ _rmul_or_fill! (C, β)
604+ else # iszero(β) && A and B are non-empty
605+ a1 = firstindex (A, 1 )
606+ a2 = firstindex (A, 2 )
607+ for j in axes (C, 2 ), i in axes (C, 1 )
608+ z1 = zero (A[i, a2]* A[a1, j] + A[i, a2]* A[a1, j])
609+ C[i,j] = convert (promote_type (typeof (z1), eltype (C)), z1)
610+ end
611+ end
612+ iszero (α) && return C
603613 @inbounds if ! conjugate
604614 if aat
605615 for k ∈ 1 : n, j ∈ 1 : m
606- αA_jk = A[j, k] * α
616+ αA_jk = @stable_muladdmul MulAddMul (α, false )( A[j, k])
607617 for i ∈ 1 : j
608618 C[i, j] += A[i, k] * αA_jk
609619 end
@@ -614,17 +624,17 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo
614624 for k ∈ 2 : m
615625 temp += A[k, i] * A[k, j]
616626 end
617- C[i, j] += temp * α
627+ C[i, j] += @stable_muladdmul MulAddMul (α, false )(temp)
618628 end
619629 end
620630 else
621631 if aat
622632 for k ∈ 1 : n, j ∈ 1 : m
623- αA_jk_bar = conj (A[j, k]) * α
633+ αA_jk_bar = @stable_muladdmul MulAddMul (α, false )( conj (A[j, k]))
624634 for i ∈ 1 : j- 1
625635 C[i, j] += A[i, k] * αA_jk_bar
626636 end
627- C[j, j] += abs2 (A[j, k]) * α
637+ C[j, j] += @stable_muladdmul MulAddMul (α, false )( abs2 (A[j, k]))
628638 end
629639 else
630640 for j ∈ 1 : n
@@ -633,13 +643,13 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo
633643 for k ∈ 2 : m
634644 temp += conj (A[k, i]) * A[k, j]
635645 end
636- C[i, j] += temp * α
646+ C[i, j] += @stable_muladdmul MulAddMul (α, false )(temp)
637647 end
638648 temp = abs2 (A[1 , j])
639649 for k ∈ 2 : m
640650 temp += abs2 (A[k, j])
641651 end
642- C[j, j] += temp * α
652+ C[j, j] += @stable_muladdmul MulAddMul (α, false )(temp)
643653 end
644654 end
645655 end
@@ -1132,8 +1142,18 @@ __generic_matmatmul!(C, A, B, alpha, beta, ::Val{true}) = _generic_matmatmul_non
11321142__generic_matmatmul! (C, A, B, alpha, beta, :: Val{false} ) = _generic_matmatmul_generic! (C, A, B, alpha, beta)
11331143
11341144function _generic_matmatmul_nonadjtrans! (C, A, B, alpha, beta)
1135- _rmul_or_fill! (C, beta)
1136- (iszero (alpha) || isempty (A) || isempty (B)) && return C
1145+ # _rmul_or_fill!(C, beta) spelled out more carefully to allow for zero-less eltypes
1146+ if (! iszero (beta) || isempty (A) || isempty (B)) # return C*beta
1147+ _rmul_or_fill! (C, beta)
1148+ else # iszero(beta) && A and B are non-empty
1149+ a1 = firstindex (A, 2 )
1150+ b1 = firstindex (B, 1 )
1151+ for j in axes (C, 2 ), i in axes (C, 1 )
1152+ z1 = zero (A[i, a1]* B[b1, j] + A[i, a1]* B[b1, j])
1153+ C[i,j] = convert (promote_type (typeof (z1), eltype (C)), z1)
1154+ end
1155+ end
1156+ iszero (alpha) && return C
11371157 @inbounds for n in axes (B, 2 ), k in axes (B, 1 )
11381158 # Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
11391159 Balpha = @stable_muladdmul MulAddMul (alpha, false )(B[k,n])
@@ -1145,20 +1165,37 @@ function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta)
11451165 C
11461166end
11471167function _generic_matmatmul_adjtrans! (C, A, B, alpha, beta)
1148- _rmul_or_fill! (C, beta)
1149- (iszero (alpha) || isempty (A) || isempty (B)) && return C
1168+ if (! iszero (beta) || isempty (A) || isempty (B)) # return C*beta
1169+ _rmul_or_fill! (C, beta)
1170+ else # iszero(beta) && A and B are non-empty
1171+ a1 = firstindex (A, 2 )
1172+ b1 = firstindex (B, 1 )
1173+ for j in axes (C, 2 ), i in axes (C, 1 )
1174+ z1 = zero (A[i, a1]* B[b1, j] + A[i, a1]* B[b1, j])
1175+ C[i,j] = convert (promote_type (typeof (z1), eltype (C)), z1)
1176+ end
1177+ end
1178+ iszero (alpha) && return C
11501179 t = _wrapperop (A)
11511180 pB = parent (B)
11521181 pA = parent (A)
11531182 tmp = similar (C, axes (C, 2 ))
11541183 ci = firstindex (C, 1 )
11551184 ta = t (alpha)
1156- for i in axes (A, 1 )
1157- mul! (tmp, pB, view (pA, :, i))
1158- @views C[ci,:] .+ = t .(ta .* tmp)
1159- ci += 1
1185+ if isone (ta)
1186+ for i in axes (A, 1 )
1187+ mul! (tmp, pB, view (pA, :, i))
1188+ @views C[ci,:] .+ = t .(tmp)
1189+ ci += 1
1190+ end
1191+ else
1192+ for i in axes (A, 1 )
1193+ mul! (tmp, pB, view (pA, :, i))
1194+ @views C[ci,:] .+ = t .(ta .* tmp)
1195+ ci += 1
1196+ end
11601197 end
1161- C
1198+ return C
11621199end
11631200function _generic_matmatmul_generic! (C, A, B, alpha, beta)
11641201 if iszero (alpha) || isempty (A) || isempty (B)
0 commit comments