Skip to content

Commit 415294a

Browse files
jishnubKristofferC
authored andcommitted
Call MulAddMul instead of multiplication in _generic_matmatmul! (#56089)
Fix https://github.com/JuliaLang/julia/issues/56085 by calling a newly created `MulAddMul` object that only wraps the `alpha` (with `beta` set to `false`). This avoids the explicit multiplication if `alpha` is known to be `isone`. (cherry picked from commit 0af99e6)
1 parent 9dda314 commit 415294a

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

stdlib/LinearAlgebra/src/matmul.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -869,7 +869,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
869869
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
870870

871871
@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
872-
_add::MulAddMul) where {T,S,R}
872+
_add::MulAddMul{ais1}) where {T,S,R,ais1}
873873
AxM = axes(A, 1)
874874
AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector`
875875
BxK = axes(B, 1)
@@ -885,11 +885,13 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
885885
if BxN != CxN
886886
throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)"))
887887
end
888+
_rmul_alpha = MulAddMul{ais1,true,typeof(_add.alpha),Bool}(_add.alpha,false)
888889
if isbitstype(R) && sizeof(R) 16 && !(A isa Adjoint || A isa Transpose)
889890
_rmul_or_fill!(C, _add.beta)
890891
(iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
891892
@inbounds for n in BxN, k in BxK
892-
Balpha = B[k,n]*_add.alpha
893+
# Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
894+
Balpha = _rmul_alpha(B[k,n])
893895
@simd for m in AxM
894896
C[m,n] = muladd(A[m,k], Balpha, C[m,n])
895897
end

stdlib/LinearAlgebra/test/matmul.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,4 +1107,22 @@ end
11071107
end
11081108
end
11091109

1110+
@testset "issue #56085" begin
1111+
struct Thing
1112+
data::Float64
1113+
end
1114+
1115+
Base.zero(::Type{Thing}) = Thing(0.)
1116+
Base.zero(::Thing) = Thing(0.)
1117+
Base.one(::Type{Thing}) = Thing(1.)
1118+
Base.one(::Thing) = Thing(1.)
1119+
Base.:+(t::Thing...) = +(getfield.(t, :data)...)
1120+
Base.:*(t::Thing...) = *(getfield.(t, :data)...)
1121+
1122+
M = Float64[1 2; 3 4]
1123+
A = Thing.(M)
1124+
1125+
@test A * A M * M
1126+
end
1127+
11101128
end # module TestMatmul

0 commit comments

Comments
 (0)