Skip to content

Commit 29ced9e

Browse files
dkarraschamilsted
andauthored
Stabilize MulAddMul strategically (#52439)
Co-authored-by: Ashley Milsted <[email protected]>
1 parent 999dde7 commit 29ced9e

File tree

6 files changed

+187
-105
lines changed

6 files changed

+187
-105
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -440,11 +440,16 @@ const BandedMatrix = Union{Bidiagonal,Diagonal,Tridiagonal,SymTridiagonal} # or
440440
const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal}
441441
const TriSym = Union{Tridiagonal,SymTridiagonal}
442442
const BiTri = Union{Bidiagonal,Tridiagonal}
443-
@inline _mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
444-
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
445-
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
446-
@inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
447-
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
443+
@inline _mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) =
444+
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
445+
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) =
446+
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
447+
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) =
448+
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
449+
@inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) =
450+
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
451+
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) =
452+
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
448453

449454
lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul())
450455
rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul())

stdlib/LinearAlgebra/src/generic.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,74 @@ end
4949
end
5050
end
5151

52+
"""
53+
@stable_muladdmul
54+
55+
Replaces a function call, that has a `MulAddMul(alpha, beta)` constructor as an
56+
argument, with a branch over possible values of `isone(alpha)` and `iszero(beta)`
57+
and constructs `MulAddMul{isone(alpha), iszero(beta)}` explicitly in each branch.
58+
For example, 'f(x, y, MulAddMul(alpha, beta))` is transformed into
59+
```
60+
if isone(alpha)
61+
if iszero(beta)
62+
f(x, y, MulAddMul{true, true, typeof(alpha), typeof(beta)}(alpha, beta))
63+
else
64+
f(x, y, MulAddMul{true, false, typeof(alpha), typeof(beta)}(alpha, beta))
65+
end
66+
else
67+
if iszero(beta)
68+
f(x, y, MulAddMul{false, true, typeof(alpha), typeof(beta)}(alpha, beta))
69+
else
70+
f(x, y, MulAddMul{false, false, typeof(alpha), typeof(beta)}(alpha, beta))
71+
end
72+
end
73+
```
74+
This avoids the type instability of the `MulAddMul(alpha, beta)` constructor,
75+
which causes runtime dispatch in case alpha and zero are not constants.
76+
"""
77+
macro stable_muladdmul(expr)
78+
expr.head == :call || throw(ArgumentError("Can only handle function calls."))
79+
for (i, e) in enumerate(expr.args)
80+
e isa Expr || continue
81+
if e.head == :call && e.args[1] == :MulAddMul && length(e.args) == 3
82+
e.args[2] isa Symbol || continue
83+
e.args[3] isa Symbol || continue
84+
local asym = e.args[2]
85+
local bsym = e.args[3]
86+
87+
local e_sub11 = copy(expr)
88+
e_sub11.args[i] = :(MulAddMul{true, true, typeof($asym), typeof($bsym)}($asym, $bsym))
89+
90+
local e_sub10 = copy(expr)
91+
e_sub10.args[i] = :(MulAddMul{true, false, typeof($asym), typeof($bsym)}($asym, $bsym))
92+
93+
local e_sub01 = copy(expr)
94+
e_sub01.args[i] = :(MulAddMul{false, true, typeof($asym), typeof($bsym)}($asym, $bsym))
95+
96+
local e_sub00 = copy(expr)
97+
e_sub00.args[i] = :(MulAddMul{false, false, typeof($asym), typeof($bsym)}($asym, $bsym))
98+
99+
local e_out = quote
100+
if isone($asym)
101+
if iszero($bsym)
102+
$e_sub11
103+
else
104+
$e_sub10
105+
end
106+
else
107+
if iszero($bsym)
108+
$e_sub01
109+
else
110+
$e_sub00
111+
end
112+
end
113+
end
114+
return esc(e_out)
115+
end
116+
end
117+
throw(ArgumentError("No valid MulAddMul expression found."))
118+
end
119+
52120
MulAddMul() = MulAddMul{true,true,Bool,Bool}(true, false)
53121

54122
@inline (::MulAddMul{true})(x) = x

0 commit comments

Comments
 (0)