Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 36 additions & 23 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,24 +202,35 @@ _lscale_add!(C::StridedArray, s::Number, X::StridedArray, alpha::Number, beta::N
generic_mul!(C, s, X, alpha, beta)
# In-place left scaling with accumulation: update `C` so that it holds
# `s * X * alpha + C * beta`, and return `C`.
#
# When `C` and `X` share the same axes:
#   * a zero `alpha` means `X` contributes nothing, so the whole update is handed
#     to `_rmul_or_fill!(C, beta)` (presumably "rescale `C` by `beta`" -- confirm
#     against that helper's definition);
#   * any other `alpha` is delegated to `_lscale_add_nonzeroalpha!`.
# When the axes differ, the call is forwarded unchanged to `generic_mul!`.
@inline function _lscale_add!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Number, beta::Number)
    if axes(C) == axes(X)
        iszero(alpha) && return _rmul_or_fill!(C, beta)
        # The update must be applied exactly once. Broadcasting into `C` here and
        # then calling the helper as well would feed the already-updated `C` back
        # into the `C * beta` term and give a wrong result for non-zero `beta`.
        _lscale_add_nonzeroalpha!(C, s, X, alpha, beta)
    else
        generic_mul!(C, s, X, alpha, beta)
    end
    return C
end
# Broadcasting kernel for `C = s * X * alpha + C * beta` with a general `alpha`.
# `C` is overwritten in place and returned.
function _lscale_add_nonzeroalpha!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Number, beta::Number)
    if isone(alpha)
        # A unit `alpha` drops out of the product, so forward the literal `true`
        # in its place: every unit-valued `alpha`, whatever its type, then shares
        # the single `alpha::Bool` method instead of compiling its own copy of
        # the branch.
        _lscale_add_nonzeroalpha!(C, s, X, true, beta)
        return C
    end
    if iszero(beta)
        # With a zero `beta` the previous contents of `C` are never read.
        @. C = s * X * alpha
    else
        @. C = s * X * alpha + C * beta
    end
    return C
end
# `alpha::Bool` method of the left-scaling kernel: overwrite `C` with
# `s * X + C * beta` and return it.
# NOTE: the value of `alpha` is never read here; it only selects this method.
# The result is always computed as if `alpha` were `true`.
function _lscale_add_nonzeroalpha!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Bool, beta::Number)
    if iszero(beta)
        # With a zero `beta` the previous contents of `C` are never read.
        @. C = s * X
        return C
    end
    @. C = s * X + C * beta
    return C
end
# 5-argument `mul!` with the scalar `s` on the right of the array `X`.
# This method only forwards its arguments, unchanged, to `_rscale_add!`.
@inline function mul!(C::AbstractArray, X::AbstractArray, s::Number, alpha::Number, beta::Number)
    return _rscale_add!(C, X, s, alpha, beta)
end

Expand All @@ -228,24 +239,26 @@ _rscale_add!(C::StridedArray, X::StridedArray, s::Number, alpha::Number, beta::N
# In-place right scaling with accumulation: update `C` so that it holds
# `X * s * alpha + C * beta`, and return `C`.
#
# When `C` and `X` share the same axes, `alpha` is folded into the scalar and the
# broadcast itself is done by `_rscale_add_alphaisone!`. When the axes differ, the
# call is forwarded unchanged to `generic_mul!`.
@inline function _rscale_add!(C::AbstractArray, X::AbstractArray, s::Number, alpha::Number, beta::Number)
    if axes(C) == axes(X)
        # The update must be applied exactly once: the helper is the only place
        # that writes to `C`. Broadcasting here as well would feed the updated `C`
        # back into the `C * beta` term and give a wrong result for non-zero `beta`.
        if isone(alpha)
            # since alpha is unused, we might as well ignore it in this branch.
            # This avoids recompiling the branch if an `alpha` of a different type is used
            _rscale_add_alphaisone!(C, X, s, beta)
        else
            # Combine the two scalars up front so the same kernel can be reused.
            s_alpha = s * alpha
            _rscale_add_alphaisone!(C, X, s_alpha, beta)
        end
    else
        generic_mul!(C, X, s, alpha, beta)
    end
    return C
end
# Broadcasting kernel for right scaling without a separate `alpha` factor:
# overwrite `C` with `X * s + C * beta` and return it.
function _rscale_add_alphaisone!(C::AbstractArray, X::AbstractArray, s::Number, beta::Number)
    if iszero(beta)
        # With a zero `beta` the previous contents of `C` are never read.
        @. C = X * s
        return C
    end
    @. C = X * s + C * beta
    return C
end

# For better performance when input and output are the same array
# See https://github.com/JuliaLang/julia/issues/8415#issuecomment-56608729
Expand Down
22 changes: 22 additions & 0 deletions test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -857,4 +857,26 @@ end
end
end

@testset "scaling mul" begin
    v = 1:4

    # 3-arg methods, scalar on either side
    w = similar(v)
    @test mul!(w, 2, v) == 2v
    @test mul!(w, v, 2) == 2v

    # 5-arg methods. Each row is `(alpha, beta, k)`: the destination starts out as
    # a copy of `v`, `v` is scaled by 2, and the result must equal `k * v`.
    cases = (
        # equivalent to the 3-arg method, but with non-Bool alpha
        (1, 0, 2),
        # alpha::Bool
        (true, 1, 3),
        (false, 2, 2),
        # general alpha and beta
        (1, 3, 5),
        (2, 3, 7),
        (2, 0, 4),
    )
    for (alpha, beta, k) in cases
        @test mul!(copy!(similar(v), v), 2, v, alpha, beta) == k * v
        @test mul!(copy!(similar(v), v), v, 2, alpha, beta) == k * v
    end
end

end # module TestGeneric