Skip to content

Commit e700433

Browse files
authored
Avoid constructing MulAddMuls (#623)
1 parent 75d7b6a commit e700433

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

src/blas/highlevel.jl

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -106,10 +106,14 @@ LinearAlgebra.generic_trimatdiv!(
106106
isunitc, A, C === B ? C : copyto!(C, B))
107107

108108
# GEMV
109-
109+
# legacy method
110+
LinearAlgebra.generic_matvecmul!(
111+
Y::ROCVector, tA::AbstractChar, A::StridedROCMatrix, B::StridedROCVector,
112+
_add::MulAddMul
113+
) = LinearAlgebra.generic_matvecmul!(Y, tA, A, B, _add.alpha, _add.beta)
110114
function LinearAlgebra.generic_matvecmul!(
111115
Y::ROCVector, tA::AbstractChar, A::StridedROCMatrix, B::StridedROCVector,
112-
_add::MulAddMul,
116+
alpha::Number, beta::Number,
113117
)
114118
mA, nA = tA == 'N' ? size(A) : reverse(size(A))
115119

@@ -122,7 +126,6 @@ function LinearAlgebra.generic_matvecmul!(
122126
nA == 0 && return rmul!(Y, 0)
123127

124128
T = eltype(Y)
125-
alpha, beta = _add.alpha, _add.beta
126129
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
127130
α, β = T(alpha), T(beta)
128131
if T <: ROCBLASFloat && eltype(A) == eltype(B) == T
@@ -135,19 +138,22 @@ function LinearAlgebra.generic_matvecmul!(
135138
end
136139
end
137140
end
138-
LinearAlgebra.generic_matmatmul!(Y, tA, 'N', A, B, MulAddMul(alpha, beta))
141+
LinearAlgebra.generic_matmatmul!(Y, tA, 'N', A, B, alpha, beta)
139142
end
140143

141144
#
142145
# BLAS 3
143146
#
144-
145-
function LinearAlgebra.generic_matmatmul!(
147+
# legacy method
148+
LinearAlgebra.generic_matmatmul!(
146149
C::StridedROCVecOrMat, tA, tB, A::StridedROCVecOrMat,
147150
B::StridedROCVecOrMat, _add::MulAddMul,
151+
) = LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
152+
function LinearAlgebra.generic_matmatmul!(
153+
C::StridedROCVecOrMat, tA, tB, A::StridedROCVecOrMat,
154+
B::StridedROCVecOrMat, alpha::Number, beta::Number,
148155
)
149156
T = eltype(C)
150-
alpha, beta = _add.alpha, _add.beta
151157
mA, nA = size(A, tA == 'N' ? 1 : 2), size(A, tA == 'N' ? 2 : 1)
152158
mB, nB = size(B, tB == 'N' ? 1 : 2), size(B, tB == 'N' ? 2 : 1)
153159

0 commit comments

Comments (0)