From 613935b5b2d26ab71655509e10bd63690255f588 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 14 Apr 2025 14:09:09 +0530 Subject: [PATCH 1/4] Branch on `Bool` `alpha` in scaling `mul!` --- src/generic.jl | 39 ++++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/src/generic.jl b/src/generic.jl index de37f081..1ac7f2bb 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -202,24 +202,37 @@ _lscale_add!(C::StridedArray, s::Number, X::StridedArray, alpha::Number, beta::N generic_mul!(C, s, X, alpha, beta) @inline function _lscale_add!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Number, beta::Number) if axes(C) == axes(X) - if isone(alpha) - if iszero(beta) - @. C = s * X - else - @. C = s * X + C * beta - end - else - if iszero(beta) - @. C = s * X * alpha - else - @. C = s * X * alpha + C * beta - end - end + iszero(alpha) && return _rmul_or_fill!(C, beta) + _lscale_add_nonzeroalpha!(C, s, X, alpha, beta) else generic_mul!(C, s, X, alpha, beta) end return C end +function _lscale_add_nonzeroalpha!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Number, beta::Number) + if isone(alpha) + if iszero(beta) + @. C = s * X + else + @. C = s * X + C * beta + end + else + if iszero(beta) + @. C = s * X * alpha + else + @. C = s * X * alpha + C * beta + end + end + C +end +function _lscale_add_nonzeroalpha!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Bool, beta::Number) + if iszero(beta) + @. C = s * X + else + @. C = s * X + C * beta + end + C +end @inline mul!(C::AbstractArray, X::AbstractArray, s::Number, alpha::Number, beta::Number) = _rscale_add!(C, X, s, alpha, beta) From 193f15b0d3de374e5c3e6f3c621da8528071f364 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 16 Apr 2025 11:07:05 +0530 Subject: [PATCH 2/4] Tests for scaling `mul!` --- test/generic.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/generic.jl b/test/generic.jl index 36dc5c94..3ff41157 100644 --- a/test/generic.jl +++ b/test/generic.jl @@ -857,4 +857,17 @@ end end end +@testset "scaling mul" begin + v = 1:4 + w = similar(v) + @test mul!(w, 2, v) == 2v + @test mul!(w, v, 2) == 2v + @test mul!(copy!(similar(v), v), 2, v, 1, 3) == 5v + @test mul!(copy!(similar(v), v), v, 2, 1, 3) == 5v + @test mul!(copy!(similar(v), v), 2, v, 2, 3) == 7v + @test mul!(copy!(similar(v), v), v, 2, 2, 3) == 7v + @test mul!(copy!(similar(v), v), 2, v, 2, 0) == 4v + @test mul!(copy!(similar(v), v), v, 2, 2, 0) == 4v +end + end # module TestGeneric From d7ff8af1a184833e3e6484bcd87d81b2e8845ece Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 18 Apr 2025 20:37:54 +0530 Subject: [PATCH 3/4] Set `alpha` to `true` if unused --- src/generic.jl | 8 +++----- test/generic.jl | 5 +++++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/generic.jl b/src/generic.jl index 1ac7f2bb..20ca0d6d 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -211,11 +211,9 @@ _lscale_add!(C::StridedArray, s::Number, X::StridedArray, alpha::Number, beta::N end function _lscale_add_nonzeroalpha!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Number, beta::Number) if isone(alpha) - if iszero(beta) - @. C = s * X - else - @. C = s * X + C * beta - end + # since alpha is unused, we might as well set to `true` to avoid recompiling + # the branch if a different type is used + _lscale_add_nonzeroalpha!(C, s, X, true, beta) else if iszero(beta) @. C = s * X * alpha diff --git a/test/generic.jl b/test/generic.jl index 3ff41157..e42feee0 100644 --- a/test/generic.jl +++ b/test/generic.jl @@ -862,6 +862,11 @@ end w = similar(v) @test mul!(w, 2, v) == 2v @test mul!(w, v, 2) == 2v + # 5-arg equivalent to the 3-arg method, but with non-Bool alpha + @test mul!(copy!(similar(v), v), 2, v, 1, 0) == 2v + @test mul!(copy!(similar(v), v), v, 2, 1, 0) == 2v + @test mul!(copy!(similar(v), v), 2, v, true, 1) == 3v + @test mul!(copy!(similar(v), v), v, 2, true, 1) == 3v @test mul!(copy!(similar(v), v), 2, v, 1, 3) == 5v @test mul!(copy!(similar(v), v), v, 2, 1, 3) == 5v @test mul!(copy!(similar(v), v), 2, v, 2, 3) == 7v From b937faef30990ee3fa27ca9bf622a36d77d15e7b Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 18 Apr 2025 21:59:36 +0530 Subject: [PATCH 4/4] Reuse code in `_rscale_add!` --- src/generic.jl | 24 +++++++++++++----------- test/generic.jl | 4 ++++ 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/generic.jl b/src/generic.jl index 20ca0d6d..de3b195c 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -212,7 +212,7 @@ end function _lscale_add_nonzeroalpha!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Number, beta::Number) if isone(alpha) # since alpha is unused, we might as well set to `true` to avoid recompiling - # the branch if a different type is used + # the branch if an `alpha` of a different type is used _lscale_add_nonzeroalpha!(C, s, X, true, beta) else if iszero(beta) @@ -239,24 +239,26 @@ _rscale_add!(C::StridedArray, X::StridedArray, s::Number, alpha::Number, beta::N @inline function _rscale_add!(C::AbstractArray, X::AbstractArray, s::Number, alpha::Number, beta::Number) if axes(C) == axes(X) if isone(alpha) - if iszero(beta) - @. C = X * s - else - @. C = X * s + C * beta - end + # since alpha is unused, we might as well ignore it in this branch. + # This avoids recompiling the branch if an `alpha` of a different type is used + _rscale_add_alphaisone!(C, X, s, beta) else s_alpha = s * alpha - if iszero(beta) - @. C = X * s_alpha - else - @. C = X * s_alpha + C * beta - end + _rscale_add_alphaisone!(C, X, s_alpha, beta) end else generic_mul!(C, X, s, alpha, beta) end return C end +function _rscale_add_alphaisone!(C::AbstractArray, X::AbstractArray, s::Number, beta::Number) + if iszero(beta) + @. C = X * s + else + @. C = X * s + C * beta + end + C +end # For better performance when input and output are the same array # See https://github.com/JuliaLang/julia/issues/8415#issuecomment-56608729 diff --git a/test/generic.jl b/test/generic.jl index e42feee0..11194a60 100644 --- a/test/generic.jl +++ b/test/generic.jl @@ -865,8 +865,12 @@ end # 5-arg equivalent to the 3-arg method, but with non-Bool alpha @test mul!(copy!(similar(v), v), 2, v, 1, 0) == 2v @test mul!(copy!(similar(v), v), v, 2, 1, 0) == 2v + # 5-arg tests with alpha::Bool @test mul!(copy!(similar(v), v), 2, v, true, 1) == 3v @test mul!(copy!(similar(v), v), v, 2, true, 1) == 3v + @test mul!(copy!(similar(v), v), 2, v, false, 2) == 2v + @test mul!(copy!(similar(v), v), v, 2, false, 2) == 2v + # 5-arg tests @test mul!(copy!(similar(v), v), 2, v, 1, 3) == 5v @test mul!(copy!(similar(v), v), v, 2, 1, 3) == 5v @test mul!(copy!(similar(v), v), 2, v, 2, 3) == 7v