338338
339339
340340# # matrix multiplication
341-
342- function generic_matmatmul! (C:: AbstractArray{R} , A:: AbstractArray{T} , B:: AbstractArray{S} , a:: Number , b:: Number ) where {T,S,R}
341+ # legacy method
342+ generic_matmatmul! (C:: AbstractArray , A:: AbstractArray , B:: AbstractArray , a:: Number , b:: Number ) where {T,S,R} =
343+ generic_matmatmul! (C, A, B, MulAddMul (a, b))
344+ function generic_matmatmul! (C:: AbstractArray{R} , A:: AbstractArray{T} , B:: AbstractArray{S} , add:: MulAddMul ) where {T,S,R}
343345 if size (A,2 ) != size (B,1 )
344346 throw (DimensionMismatch (" matrix A has dimensions $(size (A)) , matrix B has dimensions $(size (B)) " ))
345347 end
@@ -350,8 +352,6 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
350352 return fill! (C, zero (R))
351353 end
352354
353- add = MulAddMul (a, b)
354-
355355 gpu_call (C, A, B; name= " matmatmul!" ) do ctx, C, A, B
356356 idx = @linearidx C
357357 assume .(size (C) .> 0 )
@@ -372,42 +372,52 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
372372 C
373373end
374374
375+ @static if VERSION < v " 1.12.0-"
375376function LinearAlgebra. generic_matvecmul! (C:: AbstractGPUVector , tA:: AbstractChar , A:: AbstractGPUMatrix , B:: AbstractGPUVector , _add:: MulAddMul = MulAddMul ())
376- generic_matmatmul! (C, wrap (A, tA), B, _add. alpha, _add . beta )
377+ generic_matmatmul! (C, wrap (A, tA), B, _add)
377378end
378379
379380function LinearAlgebra. generic_matmatmul! (C:: AbstractGPUVecOrMat , tA, tB, A:: AbstractGPUVecOrMat , B:: AbstractGPUVecOrMat , _add:: MulAddMul = MulAddMul ())
380- generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add. alpha, _add. beta)
381+ generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add)
382+ end
383+ else
384+ function LinearAlgebra. generic_matvecmul! (C:: AbstractGPUVector , tA:: AbstractChar , A:: AbstractGPUMatrix , B:: AbstractGPUVector , a:: Number , b:: Number )
385+ LinearAlgebra. @stable_muladdmul generic_matmatmul! (C, wrap (A, tA), B, MulAddMul (a, b))
386+ end
387+
388+ function LinearAlgebra. generic_matmatmul! (C:: AbstractGPUVecOrMat , tA, tB, A:: AbstractGPUVecOrMat , B:: AbstractGPUVecOrMat , a:: Number , b:: Number )
389+ LinearAlgebra. @stable_muladdmul generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), MulAddMul (a, b))
381390end
391+ end
382392
383393if VERSION < v " 1.10.0-DEV.1365"
384394# catch other functions that are called by LinearAlgebra's mul!
385395function LinearAlgebra. gemv! (C:: AbstractGPUVector , tA:: AbstractChar , A:: AbstractGPUMatrix , B:: AbstractGPUVector , a:: Number , b:: Number )
386- generic_matmatmul! (C, wrap (A, tA), B, a, b)
396+ generic_matmatmul! (C, wrap (A, tA), B, MulAddMul ( a, b) )
387397end
388398# disambiguation
389399function LinearAlgebra. gemv! (C:: AbstractGPUVector{T} , tA:: AbstractChar , A:: AbstractGPUMatrix{T} , B:: AbstractGPUVector{T} , a:: Number , b:: Number ) where {T<: LinearAlgebra.BlasFloat }
390- generic_matmatmul! (C, wrap (A, tA), B, a, b)
400+ generic_matmatmul! (C, wrap (A, tA), B, MulAddMul ( a, b) )
391401end
392402
393403LinearAlgebra. gemm_wrapper! (C:: AbstractGPUVecOrMat , tA:: AbstractChar , tB:: AbstractChar , A:: AbstractGPUVecOrMat , B:: AbstractGPUVecOrMat , _add:: MulAddMul ) =
394- LinearAlgebra . generic_matmatmul! (C, tA, tB, A, B , _add)
404+ generic_matmatmul! (C, wrap (A, tA), wrap (B, tB) , _add)
395405# disambiguation
396406LinearAlgebra. gemm_wrapper! (C:: AbstractGPUVecOrMat{T} , tA:: AbstractChar , tB:: AbstractChar , A:: AbstractGPUVecOrMat{T} , B:: AbstractGPUVecOrMat{T} , _add:: MulAddMul ) where {T<: LinearAlgebra.BlasFloat } =
397- LinearAlgebra . generic_matmatmul! (C, tA, tB, A, B , _add)
407+ generic_matmatmul! (C, wrap (A, tA), wrap (B, tB) , _add)
398408
399409function LinearAlgebra. syrk_wrapper! (C:: AbstractGPUMatrix , tA:: AbstractChar , A:: AbstractGPUVecOrMat , _add:: MulAddMul = MulAddMul ())
400410 if tA == ' T'
401- LinearAlgebra . generic_matmatmul! (C, ' T ' , ' N ' , A , A, _add)
411+ generic_matmatmul! (C, wrap (A , ' T ' ) , A, _add)
402412 else # tA == 'N'
403- LinearAlgebra . generic_matmatmul! (C, ' N ' , ' T ' , A, A , _add)
413+ generic_matmatmul! (C, A, wrap ( A, ' T ' ) , _add)
404414 end
405415end
406416function LinearAlgebra. herk_wrapper! (C:: AbstractGPUMatrix , tA:: AbstractChar , A:: AbstractGPUVecOrMat , _add:: MulAddMul = MulAddMul ())
407417 if tA == ' C'
408- LinearAlgebra . generic_matmatmul! (C, ' C ' , ' N ' , A , A, _add)
418+ generic_matmatmul! (C, wrap (A , ' C ' ) , A, _add)
409419 else # tA == 'N'
410- LinearAlgebra . generic_matmatmul! (C, ' N ' , ' C ' , A, A , _add)
420+ generic_matmatmul! (C, A, wrap ( A, ' C ' ) , _add)
411421 end
412422end
413423end # VERSION
0 commit comments