Skip to content

Commit b937fae

Browse files
committed
Reuse code in _rscale_add!
1 parent d7ff8af commit b937fae

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

src/generic.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ end
212212
function _lscale_add_nonzeroalpha!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Number, beta::Number)
213213
if isone(alpha)
214214
# since alpha is unused, we might as well set to `true` to avoid recompiling
215-
# the branch if a different type is used
215+
# the branch if an `alpha` of a different type is used
216216
_lscale_add_nonzeroalpha!(C, s, X, true, beta)
217217
else
218218
if iszero(beta)
@@ -239,24 +239,26 @@ _rscale_add!(C::StridedArray, X::StridedArray, s::Number, alpha::Number, beta::N
239239
@inline function _rscale_add!(C::AbstractArray, X::AbstractArray, s::Number, alpha::Number, beta::Number)
240240
if axes(C) == axes(X)
241241
if isone(alpha)
242-
if iszero(beta)
243-
@. C = X * s
244-
else
245-
@. C = X * s + C * beta
246-
end
242+
# since alpha is unused, we might as well ignore it in this branch.
243+
# This avoids recompiling the branch if an `alpha` of a different type is used
244+
_rscale_add_alphaisone!(C, X, s, beta)
247245
else
248246
s_alpha = s * alpha
249-
if iszero(beta)
250-
@. C = X * s_alpha
251-
else
252-
@. C = X * s_alpha + C * beta
253-
end
247+
_rscale_add_alphaisone!(C, X, s_alpha, beta)
254248
end
255249
else
256250
generic_mul!(C, X, s, alpha, beta)
257251
end
258252
return C
259253
end
254+
function _rscale_add_alphaisone!(C::AbstractArray, X::AbstractArray, s::Number, beta::Number)
255+
if iszero(beta)
256+
@. C = X * s
257+
else
258+
@. C = X * s + C * beta
259+
end
260+
C
261+
end
260262

261263
# For better performance when input and output are the same array
262264
# See https://github.com/JuliaLang/julia/issues/8415#issuecomment-56608729

test/generic.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -865,8 +865,12 @@ end
865865
# 5-arg equivalent to the 3-arg method, but with non-Bool alpha
866866
@test mul!(copy!(similar(v), v), 2, v, 1, 0) == 2v
867867
@test mul!(copy!(similar(v), v), v, 2, 1, 0) == 2v
868+
# 5-arg tests with alpha::Bool
868869
@test mul!(copy!(similar(v), v), 2, v, true, 1) == 3v
869870
@test mul!(copy!(similar(v), v), v, 2, true, 1) == 3v
871+
@test mul!(copy!(similar(v), v), 2, v, false, 2) == 2v
872+
@test mul!(copy!(similar(v), v), v, 2, false, 2) == 2v
873+
# 5-arg tests
870874
@test mul!(copy!(similar(v), v), 2, v, 1, 3) == 5v
871875
@test mul!(copy!(similar(v), v), v, 2, 1, 3) == 5v
872876
@test mul!(copy!(similar(v), v), 2, v, 2, 3) == 7v

0 commit comments

Comments
 (0)