Skip to content

Commit 3e55c9d

Browse files
committed
Align with AMDGPU for generic_matmatmul! dispatch
1 parent 723b4d7 commit 3e55c9d

File tree

1 file changed

+28
-23
lines changed

1 file changed

+28
-23
lines changed

lib/mkl/linalg.jl

Lines changed: 28 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -131,45 +131,50 @@ if VERSION >= v"1.12-"
131 131
end
132 132
end
133 133

134-
LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStridedVecOrMat, B::oneStridedVecOrMat, _add::MulAddMul=MulAddMul()) =
135-
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
136-
function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStridedVecOrMat, B::oneStridedVecOrMat, a::Number, b::Number)
134+
LinearAlgebra.generic_matmatmul!(
135+
C::oneStridedVecOrMat, tA, tB, A::oneStridedVecOrMat,
136+
B::oneStridedVecOrMat, _add::MulAddMul,
137+
) = LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
138+
function LinearAlgebra.generic_matmatmul!(
139+
C::oneStridedVecOrMat, tA, tB, A::oneStridedVecOrMat,
140+
B::oneStridedVecOrMat, alpha::Number, beta::Number,
141+
)
137 142
T = eltype(C)
138-
alpha, beta = promote(a, b, zero(T))
139 143
mA, nA = size(A, tA == 'N' ? 1 : 2), size(A, tA == 'N' ? 2 : 1)
140 144
mB, nB = size(B, tB == 'N' ? 1 : 2), size(B, tB == 'N' ? 2 : 1)
141-
if nA != mB
142-
throw(DimensionMismatch("A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))
143-
end
144 145

145-
if C === A || B === C
146-
throw(ArgumentError("output matrix must not be aliased with input matrix"))
147-
end
146+
nA != mB && throw(DimensionMismatch(
147+
"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))
148+
(C === A || B === C) && throw(ArgumentError(
149+
"output matrix must not be aliased with input matrix"))
148 150

149 151
if mA == 0 || nA == 0 || nB == 0
150-
if size(C) != (mA, nB)
151-
throw(DimensionMismatch("C has dimensions $(size(C)), should have ($mA,$nB)"))
152-
end
152+
size(C) != (mA, nB) && throw(DimensionMismatch(
153+
"C has dimensions $(size(C)), should have ($mA,$nB)"))
153 154
return LinearAlgebra.rmul!(C, 0)
154 155
end
155 156

156-
if all(in(('N', 'T', 'C')), (tA, tB))
157-
if T <: Union{onemklFloat, onemklComplex, onemklHalf} && eltype(A) == eltype(B) == T
158-
return gemm!(tA, tB, alpha, A, B, beta, C)
159-
end
160-
end
157+
T = eltype(C)
158+
161 159
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
162 160
# TODO: should the gemm part above be included in this branch?
163-
if (tA == 'S' || tA == 's') && tB == 'N'
164-
return symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C)
161+
α, β = T(alpha), T(beta)
162+
if (
163+
all(in(('N', 'T', 'C')), (tA, tB)) && T <: Union{onemklFloat, onemklComplex, onemklHalf} &&
164+
A isa oneStridedArray{T} && B isa oneStridedArray{T}
165+
)
166+
return gemm!(tA, tB, α, A, B, β, C)
167+
elseif (tA == 'S' || tA == 's') && tB == 'N'
168+
return symm!('L', tA == 'S' ? 'U' : 'L', α, A, B, β, C)
165 169
elseif (tB == 'S' || tB == 's') && tA == 'N'
166-
return symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C)
170+
return symm!('R', tB == 'S' ? 'U' : 'L', α, B, A, β, C)
167 171
elseif (tA == 'H' || tA == 'h') && tB == 'N'
168-
return hemm!('L', tA == 'H' ? 'U' : 'L', alpha, A, B, beta, C)
172+
return hemm!('L', tA == 'H' ? 'U' : 'L', α, A, B, β, C)
169 173
elseif (tB == 'H' || tB == 'h') && tA == 'N'
170-
return hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C)
174+
return hemm!('R', tB == 'H' ? 'U' : 'L', α, B, A, β, C)
171 175
end
172 176
end
177+
173 178
GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
174 179
end
175 180

0 commit comments

Comments
 (0)