Skip to content

Commit 3fdacb5

Browse files
dkarraschamilsted
andcommitted
Stabilize MulAddMul strategically
Co-authored-by: Ashley Milsted <[email protected]>
1 parent 39ccdb2 commit 3fdacb5

File tree

5 files changed

+85
-12
lines changed

5 files changed

+85
-12
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -428,11 +428,16 @@ const BandedMatrix = Union{Bidiagonal,Diagonal,Tridiagonal,SymTridiagonal} # or
428428
const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal}
429429
const TriSym = Union{Tridiagonal,SymTridiagonal}
430430
const BiTri = Union{Bidiagonal,Tridiagonal}
431-
@inline mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
432-
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
433-
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
434-
@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
435-
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
431+
@inline mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) =
432+
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
433+
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) =
434+
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
435+
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) =
436+
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
437+
@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) =
438+
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
439+
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) =
440+
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
436441

437442
lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul())
438443
rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul())

stdlib/LinearAlgebra/src/generic.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,74 @@ end
4949
end
5050
end
5151

52+
"""
53+
@stable_muladdmul
54+
55+
Replaces a function call, that has a `MulAddMul(alpha, beta)` constructor as an
56+
argument, with a branch over possible values of `isone(alpha)` and `iszero(beta)`
57+
and constructs `MulAddMul{isone(alpha), iszero(beta)}` explicitly in each branch.
58+
For example, 'f(x, y, MulAddMul(alpha, beta))` is transformed into
59+
```
60+
if isone(alpha)
61+
if iszero(beta)
62+
f(x, y, MulAddMul{true, true, typeof(alpha), typeof(beta)}(alpha, beta))
63+
else
64+
f(x, y, MulAddMul{true, false, typeof(alpha), typeof(beta)}(alpha, beta))
65+
end
66+
else
67+
if iszero(beta)
68+
f(x, y, MulAddMul{false, true, typeof(alpha), typeof(beta)}(alpha, beta))
69+
else
70+
f(x, y, MulAddMul{false, false, typeof(alpha), typeof(beta)}(alpha, beta))
71+
end
72+
end
73+
```
74+
This avoids the type instability of the `MulAddMul(alpha, beta)` constructor,
75+
which causes runtime dispatch in case alpha and zero are not constants.
76+
"""
77+
macro stable_muladdmul(expr)
78+
expr.head == :call || throw(ArgumentError("Can only handle function calls."))
79+
for (i, e) in enumerate(expr.args)
80+
e isa Expr || continue
81+
if e.head == :call && e.args[1] == :MulAddMul && length(e.args) == 3
82+
e.args[2] isa Symbol || continue
83+
e.args[3] isa Symbol || continue
84+
local asym = e.args[2]
85+
local bsym = e.args[3]
86+
87+
local e_sub11 = copy(expr)
88+
e_sub11.args[i] = :(MulAddMul{true, true, typeof($asym), typeof($bsym)}($asym, $bsym))
89+
90+
local e_sub10 = copy(expr)
91+
e_sub10.args[i] = :(MulAddMul{true, false, typeof($asym), typeof($bsym)}($asym, $bsym))
92+
93+
local e_sub01 = copy(expr)
94+
e_sub01.args[i] = :(MulAddMul{false, true, typeof($asym), typeof($bsym)}($asym, $bsym))
95+
96+
local e_sub00 = copy(expr)
97+
e_sub00.args[i] = :(MulAddMul{false, false, typeof($asym), typeof($bsym)}($asym, $bsym))
98+
99+
local e_out = quote
100+
if isone($asym)
101+
if iszero($bsym)
102+
$e_sub11
103+
else
104+
$e_sub10
105+
end
106+
else
107+
if iszero($bsym)
108+
$e_sub01
109+
else
110+
$e_sub00
111+
end
112+
end
113+
end
114+
return esc(e_out)
115+
end
116+
end
117+
throw(ArgumentError("No valid MulAddMul expression found."))
118+
end
119+
52120
MulAddMul() = MulAddMul{true,true,Bool,Bool}(true, false)
53121

54122
@inline (::MulAddMul{true})(x) = x

stdlib/LinearAlgebra/src/matmul.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ end
6767

6868
@inline mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
6969
alpha::Number, beta::Number) =
70-
generic_matvecmul!(y, wrapper_char(A), _unwrap(A), x, MulAddMul(alpha, beta))
70+
@stable_muladdmul generic_matvecmul!(y, wrapper_char(A), _unwrap(A), x, MulAddMul(alpha, beta))
7171
# BLAS cases
7272
# equal eltypes
7373
@inline generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T},

stdlib/LinearAlgebra/src/special.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ end
111111
mul!(C::AbstractMatrix, A::AbstractTriangular, B::BandedMatrix) = _mul!(C, A, B, MulAddMul())
112112
mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractTriangular) = _mul!(C, A, B, MulAddMul())
113113
mul!(C::AbstractMatrix, A::AbstractTriangular, B::BandedMatrix, alpha::Number, beta::Number) =
114-
_mul!(C, A, B, MulAddMul(alpha, beta))
114+
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
115115
mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractTriangular, alpha::Number, beta::Number) =
116-
_mul!(C, A, B, MulAddMul(alpha, beta))
116+
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
117117

118118
function *(H::UpperHessenberg, B::Bidiagonal)
119119
T = promote_op(matprod, eltype(H), eltype(B))

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -472,9 +472,9 @@ end
472472
# Define `mul!` for (Unit){Upper,Lower}Triangular matrices times a number.
473473
# be permissive here and require compatibility later in _triscale!
474474
@inline mul!(A::AbstractTriangular, B::AbstractTriangular, C::Number, alpha::Number, beta::Number) =
475-
_triscale!(A, B, C, MulAddMul(alpha, beta))
475+
@stable_muladdmul _triscale!(A, B, C, MulAddMul(alpha, beta))
476476
@inline mul!(A::AbstractTriangular, B::Number, C::AbstractTriangular, alpha::Number, beta::Number) =
477-
_triscale!(A, B, C, MulAddMul(alpha, beta))
477+
@stable_muladdmul _triscale!(A, B, C, MulAddMul(alpha, beta))
478478

479479
function _triscale!(A::UpperTriangular, B::UpperTriangular, c::Number, _add)
480480
n = checksquare(B)
@@ -732,7 +732,7 @@ for TC in (:AbstractVector, :AbstractMatrix)
732732
if isone(alpha) && iszero(beta)
733733
return mul!(C, A, B)
734734
else
735-
return generic_matvecmul!(C, 'N', A, B, MulAddMul(alpha, beta))
735+
return @stable_muladdmul generic_matvecmul!(C, 'N', A, B, MulAddMul(alpha, beta))
736736
end
737737
end
738738
end
@@ -744,7 +744,7 @@ for (TA, TB) in ((:AbstractTriangular, :AbstractMatrix),
744744
if isone(alpha) && iszero(beta)
745745
return mul!(C, A, B)
746746
else
747-
return generic_matmatmul!(C, 'N', 'N', A, B, MulAddMul(alpha, beta))
747+
return @stable_muladdmul generic_matmatmul!(C, 'N', 'N', A, B, MulAddMul(alpha, beta))
748748
end
749749
end
750750
end

0 commit comments

Comments
 (0)