diff --git a/src/matmul.jl b/src/matmul.jl index 682691b4..d3eabfda 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -605,7 +605,7 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo aA_11 = abs2(A[1,1]) fill!(UpperTriangular(C), zero(aA_11 + aA_11)) end - iszero(α) && return C + (iszero(α) || isempty(A)) && return C @inbounds if !conjugate if aat for k ∈ 1:n, j ∈ 1:m @@ -1075,7 +1075,8 @@ function __generic_matvecmul!(::typeof(identity), C::AbstractVector, A::Abstract elseif length(B) == 0 C[i] = zero(eltype(C)) else - C[i] = zero(A[i]*B[1] + A[i]*B[1]) + ci = @stable_muladdmul MulAddMul(alpha,false)(A[i]*B[1]) + C[i] = zero(ci + ci) end end if !iszero(alpha) @@ -1147,12 +1148,12 @@ function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta) for j in axes(C, 2) B_1j = B[b1, j] for i in axes(C, 1) - C_ij = A[i, a1] * B_1j + C_ij = @stable_muladdmul MulAddMul(alpha, false)(A[i, a1] * B_1j) C[i,j] = zero(C_ij + C_ij) end end end - iszero(alpha) && return C + (iszero(alpha) || isempty(A) || isempty(B)) && return C @inbounds for n in axes(B, 2), k in axes(B, 1) # Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha) Balpha = @stable_muladdmul MulAddMul(alpha, false)(B[k,n]) @@ -1167,7 +1168,7 @@ function _generic_matmatmul_adjtrans!(C, A, B, alpha, beta) t = _wrapperop(A) pB = parent(B) pA = parent(A) - if (!iszero(beta) || isempty(A) || isempty(B)) # return C*beta + if (!iszero(beta) || isempty(A) || isempty(B)) _rmul_or_fill!(C, beta) else # iszero(beta) && A and B are non-empty a1 = firstindex(pA, 1) @@ -1175,13 +1176,13 @@ function _generic_matmatmul_adjtrans!(C, A, B, alpha, beta) for j in axes(C, 2) tB_1j = t(pB[j, b1]) for i in axes(C, 1) - C_ij = t(pA[a1, i]) * tB_1j + C_ij = @stable_muladdmul MulAddMul(alpha, false)(t(pA[a1, i]) * tB_1j) C[i,j] = zero(C_ij + C_ij) end end end - iszero(alpha) && return C - tmp = similar(C, axes(C, 2)) + (iszero(alpha) || isempty(A) || isempty(B)) && return C + tmp = similar(C, promote_op(matprod, typeof(first(A)), typeof(first(B))), axes(C, 2)) ci = firstindex(C, 1) ta = t(alpha) if isone(ta) @@ -1434,7 +1435,7 @@ mat_vec_scalar(A::StridedMaybeAdjOrTransMat, x::StridedVector, γ) = _mat_vec_sc mat_vec_scalar(A::AdjOrTransAbsVec, x::StridedVector, γ) = (A * x) * γ function _mat_vec_scalar(A, x, γ) - T = promote_type(eltype(A), eltype(x), typeof(γ)) + T = promote_op(*, promote_op(matprod, eltype(A), eltype(x)), typeof(γ)) C = similar(A, T, axes(A,1)) mul!(C, A, x, γ, false) end @@ -1444,7 +1445,7 @@ mat_mat_scalar(A::StridedMaybeAdjOrTransMat, B::StridedMaybeAdjOrTransMat, γ) = _mat_mat_scalar(A, B, γ) function _mat_mat_scalar(A, B, γ) - T = promote_type(eltype(A), eltype(B), typeof(γ)) + T = promote_op(*, promote_op(matprod, eltype(A), eltype(B)), typeof(γ)) C = similar(A, T, axes(A,1), axes(B,2)) mul!(C, A, B, γ, false) end diff --git a/test/testhelpers/Furlongs.jl b/test/testhelpers/Furlongs.jl index 3ddf42bf..97e49a05 100644 --- a/test/testhelpers/Furlongs.jl +++ b/test/testhelpers/Furlongs.jl @@ -76,6 +76,11 @@ end for op in (:(==), :(!=), :<, :<=, :isless, :isequal) @eval $op(x::Furlong{p}, y::Furlong{p}) where {p} = $op(x.val, y.val)::Bool end +for op in (:(==), :isequal) + @eval $op(x::Furlong{p}, y::Furlong{q}) where {p,q} = false + @eval $op(x::Furlong, y::Number) = $op(x, convert(Furlong, y)) + @eval $op(x::Number, y::Furlong) = $op(y, x) +end for (f,op) in ((:_plus,:+),(:_minus,:-),(:_times,:*),(:_div,://)) @eval function $f(v::T, ::Furlong{p}, ::Union{Furlong{q},Val{q}}) where {T,p,q} s = $op(p, q) diff --git a/test/unitful.jl b/test/unitful.jl index 95feed97..cbea80b3 100644 --- a/test/unitful.jl +++ b/test/unitful.jl @@ -202,4 +202,21 @@ end @test (C \ b)::Vector{<:Furlong{0}} == (D \ b)::Vector{<:Furlong{0}} == Furlong{0}.([5, -12]) end +@testset "unitful 3-arg *" begin + for n in (2, 3, 5) + λ = 5 + A = rand(-10:10, n, n) + b = rand(-10:10, n) + λu = Furlong{1}(λ) + Au = Furlong{1}.(A) + bu = Furlong{1}.(b) + @test Furlong{3}.(A*A*λ) == Au*Au*λu + @test Furlong{3}.(A'*A*λ) == Au'*Au*λu + @test Furlong{3}.(A*A'*λ) == Au*Au'*λu + @test Furlong{3}.(A'*A'*λ) == Au'*Au'*λu + @test Furlong{3}.(A*b*λ) == Au*bu*λu + @test Furlong{3}.(A'*b*λ) == Au'*bu*λu + end +end + end # module TestUnitfulLinAlg