@@ -131,45 +131,50 @@ if VERSION >= v"1.12-"
131131 end
132132end
133133
134- LinearAlgebra. generic_matmatmul! (C:: oneStridedMatrix , tA, tB, A:: oneStridedVecOrMat , B:: oneStridedVecOrMat , _add:: MulAddMul = MulAddMul ()) =
135- LinearAlgebra. generic_matmatmul! (C, tA, tB, A, B, _add. alpha, _add. beta)
136- function LinearAlgebra. generic_matmatmul! (C:: oneStridedMatrix , tA, tB, A:: oneStridedVecOrMat , B:: oneStridedVecOrMat , a:: Number , b:: Number )
134+ LinearAlgebra. generic_matmatmul! (
135+ C:: oneStridedVecOrMat , tA, tB, A:: oneStridedVecOrMat ,
136+ B:: oneStridedVecOrMat , _add:: MulAddMul ,
137+ ) = LinearAlgebra. generic_matmatmul! (C, tA, tB, A, B, _add. alpha, _add. beta)
138+ function LinearAlgebra. generic_matmatmul! (
139+ C:: oneStridedVecOrMat , tA, tB, A:: oneStridedVecOrMat ,
140+ B:: oneStridedVecOrMat , alpha:: Number , beta:: Number ,
141+ )
137142 T = eltype (C)
138- alpha, beta = promote (a, b, zero (T))
139143 mA, nA = size (A, tA == ' N' ? 1 : 2 ), size (A, tA == ' N' ? 2 : 1 )
140144 mB, nB = size (B, tB == ' N' ? 1 : 2 ), size (B, tB == ' N' ? 2 : 1 )
141- if nA != mB
142- throw (DimensionMismatch (" A has dimensions ($mA ,$nA ) but B has dimensions ($mB ,$nB )" ))
143- end
144145
145- if C === A || B === C
146- throw (ArgumentError (" output matrix must not be aliased with input matrix" ))
147- end
146+ nA != mB && throw (DimensionMismatch (
147+ " A has dimensions ($mA ,$nA ) but B has dimensions ($mB ,$nB )" ))
148+ (C === A || B === C) && throw (ArgumentError (
149+ " output matrix must not be aliased with input matrix" ))
148150
149151 if mA == 0 || nA == 0 || nB == 0
150- if size (C) != (mA, nB)
151- throw (DimensionMismatch (" C has dimensions $(size (C)) , should have ($mA ,$nB )" ))
152- end
152+ size (C) != (mA, nB) && throw (DimensionMismatch (
153+ " C has dimensions $(size (C)) , should have ($mA ,$nB )" ))
153154 return LinearAlgebra. rmul! (C, 0 )
154155 end
155156
156- if all (in ((' N' , ' T' , ' C' )), (tA, tB))
157- if T <: Union{onemklFloat, onemklComplex, onemklHalf} && eltype (A) == eltype (B) == T
158- return gemm! (tA, tB, alpha, A, B, beta, C)
159- end
160- end
157+ T = eltype (C)
158+
161159 if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
162160 # TODO : should the gemm part above be included in this branch?
163- if (tA == ' S' || tA == ' s' ) && tB == ' N'
164- return symm! (' L' , tA == ' S' ? ' U' : ' L' , alpha, A, B, beta, C)
161+ α, β = T (alpha), T (beta)
162+ if (
163+ all (in ((' N' , ' T' , ' C' )), (tA, tB)) && T <: Union{onemklFloat, onemklComplex, onemklHalf} &&
164+ A isa oneStridedArray{T} && B isa oneStridedArray{T}
165+ )
166+ return gemm! (tA, tB, α, A, B, β, C)
167+ elseif (tA == ' S' || tA == ' s' ) && tB == ' N'
168+ return symm! (' L' , tA == ' S' ? ' U' : ' L' , α, A, B, β, C)
165169 elseif (tB == ' S' || tB == ' s' ) && tA == ' N'
166- return symm! (' R' , tB == ' S' ? ' U' : ' L' , alpha , B, A, beta , C)
170+ return symm! (' R' , tB == ' S' ? ' U' : ' L' , α , B, A, β , C)
167171 elseif (tA == ' H' || tA == ' h' ) && tB == ' N'
168- return hemm! (' L' , tA == ' H' ? ' U' : ' L' , alpha , A, B, beta , C)
172+ return hemm! (' L' , tA == ' H' ? ' U' : ' L' , α , A, B, β , C)
169173 elseif (tB == ' H' || tB == ' h' ) && tA == ' N'
170- return hemm! (' R' , tB == ' H' ? ' U' : ' L' , alpha , B, A, beta , C)
174+ return hemm! (' R' , tB == ' H' ? ' U' : ' L' , α , B, A, β , C)
171175 end
172176 end
177+
173178 GPUArrays. generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
174179end
175180
0 commit comments