diff --git a/src/generic.jl b/src/generic.jl index de37f081..de3b195c 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -202,24 +202,35 @@ _lscale_add!(C::StridedArray, s::Number, X::StridedArray, alpha::Number, beta::N generic_mul!(C, s, X, alpha, beta) @inline function _lscale_add!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Number, beta::Number) if axes(C) == axes(X) - if isone(alpha) - if iszero(beta) - @. C = s * X - else - @. C = s * X + C * beta - end - else - if iszero(beta) - @. C = s * X * alpha - else - @. C = s * X * alpha + C * beta - end - end + iszero(alpha) && return _rmul_or_fill!(C, beta) + _lscale_add_nonzeroalpha!(C, s, X, alpha, beta) else generic_mul!(C, s, X, alpha, beta) end return C end +function _lscale_add_nonzeroalpha!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Number, beta::Number) + if isone(alpha) + # since alpha is unused, we might as well set to `true` to avoid recompiling + # the branch if an `alpha` of a different type is used + _lscale_add_nonzeroalpha!(C, s, X, true, beta) + else + if iszero(beta) + @. C = s * X * alpha + else + @. C = s * X * alpha + C * beta + end + end + C +end +function _lscale_add_nonzeroalpha!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Bool, beta::Number) + if iszero(beta) + @. C = s * X + else + @. C = s * X + C * beta + end + C +end @inline mul!(C::AbstractArray, X::AbstractArray, s::Number, alpha::Number, beta::Number) = _rscale_add!(C, X, s, alpha, beta) @@ -228,24 +239,26 @@ _rscale_add!(C::StridedArray, X::StridedArray, s::Number, alpha::Number, beta::N @inline function _rscale_add!(C::AbstractArray, X::AbstractArray, s::Number, alpha::Number, beta::Number) if axes(C) == axes(X) if isone(alpha) - if iszero(beta) - @. C = X * s - else - @. C = X * s + C * beta - end + # since alpha is unused, we might as well ignore it in this branch. + # This avoids recompiling the branch if an `alpha` of a different type is used + _rscale_add_alphaisone!(C, X, s, beta) else s_alpha = s * alpha - if iszero(beta) - @. C = X * s_alpha - else - @. C = X * s_alpha + C * beta - end + _rscale_add_alphaisone!(C, X, s_alpha, beta) end else generic_mul!(C, X, s, alpha, beta) end return C end +function _rscale_add_alphaisone!(C::AbstractArray, X::AbstractArray, s::Number, beta::Number) + if iszero(beta) + @. C = X * s + else + @. C = X * s + C * beta + end + C +end # For better performance when input and output are the same array # See https://github.com/JuliaLang/julia/issues/8415#issuecomment-56608729 diff --git a/test/generic.jl b/test/generic.jl index 36dc5c94..11194a60 100644 --- a/test/generic.jl +++ b/test/generic.jl @@ -857,4 +857,26 @@ end end end +@testset "scaling mul" begin + v = 1:4 + w = similar(v) + @test mul!(w, 2, v) == 2v + @test mul!(w, v, 2) == 2v + # 5-arg equivalent to the 3-arg method, but with non-Bool alpha + @test mul!(copy!(similar(v), v), 2, v, 1, 0) == 2v + @test mul!(copy!(similar(v), v), v, 2, 1, 0) == 2v + # 5-arg tests with alpha::Bool + @test mul!(copy!(similar(v), v), 2, v, true, 1) == 3v + @test mul!(copy!(similar(v), v), v, 2, true, 1) == 3v + @test mul!(copy!(similar(v), v), 2, v, false, 2) == 2v + @test mul!(copy!(similar(v), v), v, 2, false, 2) == 2v + # 5-arg tests + @test mul!(copy!(similar(v), v), 2, v, 1, 3) == 5v + @test mul!(copy!(similar(v), v), v, 2, 1, 3) == 5v + @test mul!(copy!(similar(v), v), 2, v, 2, 3) == 7v + @test mul!(copy!(similar(v), v), v, 2, 2, 3) == 7v + @test mul!(copy!(similar(v), v), 2, v, 2, 0) == 4v + @test mul!(copy!(similar(v), v), v, 2, 2, 0) == 4v +end + end # module TestGeneric