Skip to content
20 changes: 10 additions & 10 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,16 +407,16 @@ end

const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal}
const BiTri = Union{Bidiagonal,Tridiagonal}
@inline mul!(C::AbstractVector, A::BiTriSym, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::AbstractMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Transpose{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Adjoint{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractVector, A::BiTriSym, B::AbstractVector, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::AbstractMatrix, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Transpose{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Adjoint{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))

function check_A_mul_B!_sizes(C, A, B)
mA, nA = size(A)
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -585,14 +585,14 @@ for Tri in (:UpperTriangular, :LowerTriangular)
iszero(α) && return _rmul_or_fill!(C, β)
diag′ = iszero(β) ? nothing : diag(C)
data = mul!(C.data, D, A.data, α, β)
$Tri(_setdiag!(data, MulAddMul(α, β), D.diag, diag′))
$Tri(@stable_muladdmul _setdiag!(data, MulAddMul(α, β), D.diag, diag′))
end
@eval @inline mul!(C::$Tri, A::$Tri, D::Diagonal, α::Number, β::Number) = $Tri(mul!(C.data, A.data, D, α, β))
@eval @inline function mul!(C::$Tri, A::$UTri, D::Diagonal, α::Number, β::Number)
iszero(α) && return _rmul_or_fill!(C, β)
diag′ = iszero(β) ? nothing : diag(C)
data = mul!(C.data, A.data, D, α, β)
$Tri(_setdiag!(data, MulAddMul(α, β), D.diag, diag′))
$Tri(@stable_muladdmul _setdiag!(data, MulAddMul(α, β), D.diag, diag′))
end
end

Expand Down
70 changes: 70 additions & 0 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,76 @@ end
end
end

"""
@stable_muladdmul

Replaces a function call, that has a `MulAddMul(alpha, beta)` constructor as an
argument, with a branch over possible values of `isone(alpha)` and `iszero(beta)`
and constructs `MulAddMul{isone(alpha), iszero(beta)}` explicitly in each branch.

For example, 'f(x, y, MulAddMul(alpha, beta))` is transformed into
```
if isone(alpha)
if iszero(beta)
f(x, y, MulAddMul{true, true, typeof(alpha), typeof(beta)}(alpha, beta))
else
f(x, y, MulAddMul{true, false, typeof(alpha), typeof(beta)}(alpha, beta))
end
else
if iszero(beta)
f(x, y, MulAddMul{false, true, typeof(alpha), typeof(beta)}(alpha, beta))
else
f(x, y, MulAddMul{false, false, typeof(alpha), typeof(beta)}(alpha, beta))
end
end
```

This avoids the type instability of the `MulAddMul(alpha, beta)` constructor,
which causes runtime dispatch in case alpha and zero are not constants.
"""
macro stable_muladdmul(expr)
expr.head == :call || throw(ArgumentError("Can only handle function calls."))
for (i, e) in enumerate(expr.args)
e isa Expr || continue
if e.head == :call && e.args[1] == :MulAddMul && length(e.args) == 3
e.args[2] isa Symbol || continue
e.args[3] isa Symbol || continue
local asym = e.args[2]
local bsym = e.args[3]

local e_sub11 = copy(expr)
e_sub11.args[i] = :(MulAddMul{true, true, typeof($asym), typeof($bsym)}($asym, $bsym))

local e_sub10 = copy(expr)
e_sub10.args[i] = :(MulAddMul{true, false, typeof($asym), typeof($bsym)}($asym, $bsym))

local e_sub01 = copy(expr)
e_sub01.args[i] = :(MulAddMul{false, true, typeof($asym), typeof($bsym)}($asym, $bsym))

local e_sub00 = copy(expr)
e_sub00.args[i] = :(MulAddMul{false, false, typeof($asym), typeof($bsym)}($asym, $bsym))

local e_out = quote
if isone($asym)
if iszero($bsym)
$e_sub11
else
$e_sub10
end
else
if iszero($bsym)
$e_sub01
else
$e_sub00
end
end
end
return esc(e_out)
end
end
throw(ArgumentError("No valid MulAddMul expression found."))
end

MulAddMul() = MulAddMul{true,true,Bool,Bool}(true, false)

@inline (::MulAddMul{true})(x) = x
Expand Down
Loading