From 21d9f2c1793388152649b6a98d9b3fed4d7714d4 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Sat, 29 Nov 2025 12:56:11 +0100 Subject: [PATCH 1/4] Fix unitful 3-arg `*` --- src/matmul.jl | 21 +++++++++++---------- test/unitful.jl | 17 +++++++++++++++++ 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index 682691b4..d3eabfda 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -605,7 +605,7 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo aA_11 = abs2(A[1,1]) fill!(UpperTriangular(C), zero(aA_11 + aA_11)) end - iszero(α) && return C + (iszero(α) || isempty(A)) && return C @inbounds if !conjugate if aat for k ∈ 1:n, j ∈ 1:m @@ -1075,7 +1075,8 @@ function __generic_matvecmul!(::typeof(identity), C::AbstractVector, A::Abstract elseif length(B) == 0 C[i] = zero(eltype(C)) else - C[i] = zero(A[i]*B[1] + A[i]*B[1]) + ci = @stable_muladdmul MulAddMul(alpha,false)(A[i]*B[1]) + C[i] = zero(ci + ci) end end if !iszero(alpha) @@ -1147,12 +1148,12 @@ function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta) for j in axes(C, 2) B_1j = B[b1, j] for i in axes(C, 1) - C_ij = A[i, a1] * B_1j + C_ij = @stable_muladdmul MulAddMul(alpha, false)(A[i, a1] * B_1j) C[i,j] = zero(C_ij + C_ij) end end end - iszero(alpha) && return C + (iszero(alpha) || isempty(A) || isempty(B)) && return C @inbounds for n in axes(B, 2), k in axes(B, 1) # Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha) Balpha = @stable_muladdmul MulAddMul(alpha, false)(B[k,n]) @@ -1167,7 +1168,7 @@ function _generic_matmatmul_adjtrans!(C, A, B, alpha, beta) t = _wrapperop(A) pB = parent(B) pA = parent(A) - if (!iszero(beta) || isempty(A) || isempty(B)) # return C*beta + if (!iszero(beta) || isempty(A) || isempty(B)) _rmul_or_fill!(C, beta) else # iszero(beta) && A and B are non-empty a1 = firstindex(pA, 1) @@ -1175,13 +1176,13 @@ function _generic_matmatmul_adjtrans!(C, A, B, alpha, beta) for j in axes(C, 2) tB_1j = t(pB[j, b1]) for i in axes(C, 1) - C_ij = t(pA[a1, i]) * tB_1j + C_ij = @stable_muladdmul MulAddMul(alpha, false)(t(pA[a1, i]) * tB_1j) C[i,j] = zero(C_ij + C_ij) end end end - iszero(alpha) && return C - tmp = similar(C, axes(C, 2)) + (iszero(alpha) || isempty(A) || isempty(B)) && return C + tmp = similar(C, promote_op(matprod, typeof(first(A)), typeof(first(B))), axes(C, 2)) ci = firstindex(C, 1) ta = t(alpha) if isone(ta) @@ -1434,7 +1435,7 @@ mat_vec_scalar(A::StridedMaybeAdjOrTransMat, x::StridedVector, γ) = _mat_vec_sc mat_vec_scalar(A::AdjOrTransAbsVec, x::StridedVector, γ) = (A * x) * γ function _mat_vec_scalar(A, x, γ) - T = promote_type(eltype(A), eltype(x), typeof(γ)) + T = promote_op(*, promote_op(matprod, eltype(A), eltype(x)), typeof(γ)) C = similar(A, T, axes(A,1)) mul!(C, A, x, γ, false) end @@ -1444,7 +1445,7 @@ mat_mat_scalar(A::StridedMaybeAdjOrTransMat, B::StridedMaybeAdjOrTransMat, γ) = _mat_mat_scalar(A, B, γ) function _mat_mat_scalar(A, B, γ) - T = promote_type(eltype(A), eltype(B), typeof(γ)) + T = promote_op(*, promote_op(matprod, eltype(A), eltype(B)), typeof(γ)) C = similar(A, T, axes(A,1), axes(B,2)) mul!(C, A, B, γ, false) end diff --git a/test/unitful.jl b/test/unitful.jl index 95feed97..59bf5a45 100644 --- a/test/unitful.jl +++ b/test/unitful.jl @@ -202,4 +202,21 @@ end @test (C \ b)::Vector{<:Furlong{0}} == (D \ b)::Vector{<:Furlong{0}} == Furlong{0}.([5, -12]) end +@testset "unitful 3-arg *" begin + for n in (2, 3, 5) + λ = 5 + A = randn(-10:10, n, n) + b = randn(-10:10, n) + λu = Furlong{1}(λ) + Au = Furlong{1}.(A) + bu = Furlong{1}.(b) + @test Furlong{3}.(A*A*λ) == Au*Au*λu + @test Furlong{3}.(A'*A*λ) == Au'*Au*λu + @test Furlong{3}.(A*A'*λ) == Au*Au'*λu + @test Furlong{3}.(A'*A'*λ) == Au'*Au'*λu + @test Furlong{3}.(A*b*λ) == Au*bu*λu + @test Furlong{3}.(A'*b*λ) == Au'*bu*λu + end +end + end # module TestUnitfulLinAlg From ca7da5441e1f409b9bd3020a17a2addc182d9cb4 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Sat, 29 Nov 2025 13:28:12 +0100 Subject: [PATCH 2/4] fix typo --- test/unitful.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/unitful.jl b/test/unitful.jl index 59bf5a45..cbea80b3 100644 --- a/test/unitful.jl +++ b/test/unitful.jl @@ -205,8 +205,8 @@ end @testset "unitful 3-arg *" begin for n in (2, 3, 5) λ = 5 - A = randn(-10:10, n, n) - b = randn(-10:10, n) + A = rand(-10:10, n, n) + b = rand(-10:10, n) λu = Furlong{1}(λ) Au = Furlong{1}.(A) bu = Furlong{1}.(b) From 3acd59ad2da8833c39449b08e682002928edb591 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Sat, 29 Nov 2025 14:28:39 +0100 Subject: [PATCH 3/4] add comparison between Furlong and Number --- test/testhelpers/Furlongs.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/testhelpers/Furlongs.jl b/test/testhelpers/Furlongs.jl index 3ddf42bf..a7db6405 100644 --- a/test/testhelpers/Furlongs.jl +++ b/test/testhelpers/Furlongs.jl @@ -76,6 +76,11 @@ end for op in (:(==), :(!=), :<, :<=, :isless, :isequal) @eval $op(x::Furlong{p}, y::Furlong{p}) where {p} = $op(x.val, y.val)::Bool end +for op in (:(==), :isequal) + @eval $op(x::Furlong{p}, y::Furlong{q}) where {p,q} = false + @eval $op(x::Furlong, y::Number) = $op(promote(x, y)...) + @eval $op(x::Number, y::Furlong) = $op(y, x) +end for (f,op) in ((:_plus,:+),(:_minus,:-),(:_times,:*),(:_div,://)) @eval function $f(v::T, ::Furlong{p}, ::Union{Furlong{q},Val{q}}) where {T,p,q} s = $op(p, q) From 8f27f1f2bb11a1fddce59e9b4136bab0a9e5b146 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Sat, 29 Nov 2025 16:52:57 +0100 Subject: [PATCH 4/4] don't promote, convert instead --- test/testhelpers/Furlongs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/testhelpers/Furlongs.jl b/test/testhelpers/Furlongs.jl index a7db6405..97e49a05 100644 --- a/test/testhelpers/Furlongs.jl +++ b/test/testhelpers/Furlongs.jl @@ -78,7 +78,7 @@ for op in (:(==), :(!=), :<, :<=, :isless, :isequal) end for op in (:(==), :isequal) @eval $op(x::Furlong{p}, y::Furlong{q}) where {p,q} = false - @eval $op(x::Furlong, y::Number) = $op(promote(x, y)...) + @eval $op(x::Furlong, y::Number) = $op(x, convert(Furlong, y)) @eval $op(x::Number, y::Furlong) = $op(y, x) end for (f,op) in ((:_plus,:+),(:_minus,:-),(:_times,:*),(:_div,://))