Skip to content

Commit 613935b

Browse files
committed
Branch on Bool alpha in scaling mul!
1 parent 830ea2f commit 613935b

File tree

1 file changed

+26
-13
lines changed

1 file changed

+26
-13
lines changed

src/generic.jl

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -202,24 +202,37 @@ _lscale_add!(C::StridedArray, s::Number, X::StridedArray, alpha::Number, beta::N
202202
generic_mul!(C, s, X, alpha, beta)
203203
@inline function _lscale_add!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Number, beta::Number)
204204
if axes(C) == axes(X)
205-
if isone(alpha)
206-
if iszero(beta)
207-
@. C = s * X
208-
else
209-
@. C = s * X + C * beta
210-
end
211-
else
212-
if iszero(beta)
213-
@. C = s * X * alpha
214-
else
215-
@. C = s * X * alpha + C * beta
216-
end
217-
end
205+
iszero(alpha) && return _rmul_or_fill!(C, beta)
206+
_lscale_add_nonzeroalpha!(C, s, X, alpha, beta)
218207
else
219208
generic_mul!(C, s, X, alpha, beta)
220209
end
221210
return C
222211
end
212+
function _lscale_add_nonzeroalpha!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Number, beta::Number)
213+
if isone(alpha)
214+
if iszero(beta)
215+
@. C = s * X
216+
else
217+
@. C = s * X + C * beta
218+
end
219+
else
220+
if iszero(beta)
221+
@. C = s * X * alpha
222+
else
223+
@. C = s * X * alpha + C * beta
224+
end
225+
end
226+
C
227+
end
228+
function _lscale_add_nonzeroalpha!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Bool, beta::Number)
229+
if iszero(beta)
230+
@. C = s * X
231+
else
232+
@. C = s * X + C * beta
233+
end
234+
C
235+
end
223236
@inline mul!(C::AbstractArray, X::AbstractArray, s::Number, alpha::Number, beta::Number) =
224237
_rscale_add!(C, X, s, alpha, beta)
225238

0 commit comments

Comments
 (0)