# Review preamble. Text placed before the first "diff --git" header is not part of any hunk.
#
# Layout: one diff line per text line. Leading indentation and the position of blank
# context lines were reconstructed so that every hunk matches its "@@" line counts.
# TODO confirm against the target files before applying (a whitespace-tolerant mode
# such as `git apply --ignore-whitespace` may be needed).
#
# What the hunks below change, as visible in this patch:
#   * generic_syrk!: `_rmul_or_fill!(C, β)` is now called only when β is non-zero or A is
#     empty. Otherwise UpperTriangular(C) is filled with a zero built from
#     `abs2(A[1,1])`. The function returns C right after that step when α is zero, and
#     every `x * α` is replaced by `@stable_muladdmul MulAddMul(α, false)(x)`.
#   * _generic_matmatmul_nonadjtrans! and _generic_matmatmul_adjtrans!: same pattern. When
#     beta is zero and both operands are non-empty, each C[i,j] is set to
#     `zero(C_ij + C_ij)`, where C_ij is the product of the first entries of the
#     matching row of A and column of B.
#   * _generic_matmatmul_adjtrans!: additionally gains an `isone(ta)` branch that skips
#     the `ta .* tmp` multiplication, and ends with an explicit `return C`.
#   * test/matmul.jl: adds a testset with a `Mod <: Real` type that defines `+`, `*` and
#     `zero(::Mod)` (instance method only) and compares products against integer results.
#
# NOTE(review): in generic_syrk! the new zero-β branch writes only UpperTriangular(C),
#   whereas the replaced `_rmul_or_fill!(C, β)` call received all of C. Confirm callers
#   never read the lower triangle of C before it is rewritten.
# NOTE(review): both matmatmul kernels now return early only on `iszero(alpha)`. The old
#   `isempty(A) || isempty(B)` early return is gone, so the accumulation code that
#   follows is now also reached for empty operands. Confirm that is intended.
# NOTE(review): the `@test_throws MethodError` case on a 5×0 matrix presumably relies on
#   `zero(Mod)` (type argument) being undefined. Confirm.
diff --git a/src/matmul.jl b/src/matmul.jl
index d618bcfe..311ddfcf 100644
--- a/src/matmul.jl
+++ b/src/matmul.jl
@@ -599,11 +599,17 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo
         throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))"))
     end
 
-    _rmul_or_fill!(C, β)
+    if (!iszero(β) || isempty(A)) # return C*beta
+        _rmul_or_fill!(C, β)
+    else # iszero(β) && A is non-empty
+        aA_11 = abs2(A[1,1])
+        fill!(UpperTriangular(C), zero(aA_11 + aA_11))
+    end
+    iszero(α) && return C
     @inbounds if !conjugate
         if aat
             for k ∈ 1:n, j ∈ 1:m
-                αA_jk = A[j, k] * α
+                αA_jk = @stable_muladdmul MulAddMul(α, false)(A[j, k])
                 for i ∈ 1:j
                     C[i, j] += A[i, k] * αA_jk
                 end
@@ -614,17 +620,17 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo
                 for k ∈ 2:m
                     temp += A[k, i] * A[k, j]
                 end
-                C[i, j] += temp * α
+                C[i, j] += @stable_muladdmul MulAddMul(α, false)(temp)
             end
         end
     else
         if aat
             for k ∈ 1:n, j ∈ 1:m
-                αA_jk_bar = conj(A[j, k]) * α
+                αA_jk_bar = @stable_muladdmul MulAddMul(α, false)(conj(A[j, k]))
                 for i ∈ 1:j-1
                     C[i, j] += A[i, k] * αA_jk_bar
                 end
-                C[j, j] += abs2(A[j, k]) * α
+                C[j, j] += @stable_muladdmul MulAddMul(α, false)(abs2(A[j, k]))
             end
         else
             for j ∈ 1:n
@@ -633,13 +639,13 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo
                     for k ∈ 2:m
                         temp += conj(A[k, i]) * A[k, j]
                     end
-                    C[i, j] += temp * α
+                    C[i, j] += @stable_muladdmul MulAddMul(α, false)(temp)
                 end
                 temp = abs2(A[1, j])
                 for k ∈ 2:m
                     temp += abs2(A[k, j])
                 end
-                C[j, j] += temp * α
+                C[j, j] += @stable_muladdmul MulAddMul(α, false)(temp)
             end
         end
     end
@@ -1132,8 +1138,21 @@ __generic_matmatmul!(C, A, B, alpha, beta, ::Val{true}) = _generic_matmatmul_non
 __generic_matmatmul!(C, A, B, alpha, beta, ::Val{false}) = _generic_matmatmul_generic!(C, A, B, alpha, beta)
 
 function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta)
-    _rmul_or_fill!(C, beta)
-    (iszero(alpha) || isempty(A) || isempty(B)) && return C
+    # _rmul_or_fill!(C, beta) spelled out more carefully to allow for zero-less eltypes
+    if (!iszero(beta) || isempty(A) || isempty(B)) # return C*beta
+        _rmul_or_fill!(C, beta)
+    else # iszero(beta) && A and B are non-empty
+        a1 = firstindex(A, 2)
+        b1 = firstindex(B, 1)
+        for j in axes(C, 2)
+            B_1j = B[b1, j]
+            for i in axes(C, 1)
+                C_ij = A[i, a1] * B_1j
+                C[i,j] = zero(C_ij + C_ij)
+            end
+        end
+    end
+    iszero(alpha) && return C
     @inbounds for n in axes(B, 2), k in axes(B, 1)
         # Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
         Balpha = @stable_muladdmul MulAddMul(alpha, false)(B[k,n])
@@ -1145,20 +1164,40 @@ function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta)
     C
 end
 function _generic_matmatmul_adjtrans!(C, A, B, alpha, beta)
-    _rmul_or_fill!(C, beta)
-    (iszero(alpha) || isempty(A) || isempty(B)) && return C
     t = _wrapperop(A)
     pB = parent(B)
     pA = parent(A)
+    if (!iszero(beta) || isempty(A) || isempty(B)) # return C*beta
+        _rmul_or_fill!(C, beta)
+    else # iszero(beta) && A and B are non-empty
+        a1 = firstindex(pA, 1)
+        b1 = firstindex(pB, 2)
+        for j in axes(C, 2)
+            tB_1j = t(pB[j, b1])
+            for i in axes(C, 1)
+                C_ij = t(pA[a1, i]) * tB_1j
+                C[i,j] = zero(C_ij + C_ij)
+            end
+        end
+    end
+    iszero(alpha) && return C
     tmp = similar(C, axes(C, 2))
     ci = firstindex(C, 1)
     ta = t(alpha)
-    for i in axes(A, 1)
-        mul!(tmp, pB, view(pA, :, i))
-        @views C[ci,:] .+= t.(ta .* tmp)
-        ci += 1
+    if isone(ta)
+        for i in axes(A, 1)
+            mul!(tmp, pB, view(pA, :, i))
+            @views C[ci,:] .+= t.(tmp)
+            ci += 1
+        end
+    else
+        for i in axes(A, 1)
+            mul!(tmp, pB, view(pA, :, i))
+            @views C[ci,:] .+= t.(ta .* tmp)
+            ci += 1
+        end
     end
-    C
+    return C
 end
 function _generic_matmatmul_generic!(C, A, B, alpha, beta)
     if iszero(alpha) || isempty(A) || isempty(B)
diff --git a/test/matmul.jl b/test/matmul.jl
index 1fcf2009..067f49cb 100644
--- a/test/matmul.jl
+++ b/test/matmul.jl
@@ -1241,4 +1241,29 @@ end
     @test C1 ≈ C2
 end
 
+@testset "matmul with zero-less types" begin
+    struct Mod <: Real
+        val::Int
+        modulo::Int
+        Mod(x::Int, y::Int) = new(x % y, y)
+    end
+
+    Base.:+(x::Mod, y::Mod) = Mod(x.val + y.val, x.modulo)
+    Base.:*(x::Mod, y::Mod) = Mod(x.val * y.val, x.modulo)
+    Base.zero(x::Mod) = Mod(0, x.modulo)
+
+    m = Mod.(rand(0:19, 5, 0), 20)
+    @test_throws MethodError m * copy(m')
+    for n in (2, 3, 5)
+        A = rand(0:19, n, n)
+        M = Mod.(A, 20)
+        @test M * M == Mod.(A * A, 20)
+        @test M' * M == Mod.(A' * A, 20)
+        @test M * M' == Mod.(A * A', 20)
+        @test M' * M' == Mod.(A' * A', 20)
+        @test M * M[:, 1] == Mod.(A * A[:, 1], 20)
+        @test M' * M[:, 1] == Mod.(A' * A[:, 1], 20)
+    end
+end
+
 end # module TestMatmul