|
107 | 107 | ## copy a triangular part of a matrix to another matrix |
108 | 108 |
|
109 | 109 | if isdefined(LinearAlgebra, :copytrito!) |
110 | | - function LinearAlgebra.copytrito!(B::AbstractGPUMatrix, A::AbstractGPUMatrix, uplo::AbstractChar) |
| 110 | + function LinearAlgebra.copytrito!(B::AbstractGPUMatrix{T}, A::AbstractGPUMatrix{T}, uplo::AbstractChar) where {T} |
111 | 111 | LinearAlgebra.BLAS.chkuplo(uplo) |
112 | 112 | m,n = size(A) |
113 | 113 | m1,n1 = size(B) |
@@ -376,6 +376,13 @@ function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::Abs |
376 | 376 | LinearAlgebra.@stable_muladdmul generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(a, b)) |
377 | 377 | end |
378 | 378 | end |
| 379 | +@static if VERSION ≥ v"1.12.0-rc" |
| 380 | + # Only defined on Julia ≥ 1.12.0-rc: overload the generic wrapper for GPU arrays with the SyrkHerkGemm flag and forward straight to generic_matmatmul! (the `val` flag itself is unused), to avoid dispatch to the 2x2or3x3 method |
| 381 | + using LinearAlgebra: generic_matmatmul_wrapper!, BlasFlag |
| 382 | + function LinearAlgebra.generic_matmatmul_wrapper!(C::AbstractGPUMatrix{T}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, alpha::Number, beta::Number, val::LinearAlgebra.BlasFlag.SyrkHerkGemm) where {T} |
| 383 | + LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta) |
| 384 | + end |
| 385 | +end |
379 | 386 |
|
380 | 387 | function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R} |
381 | 388 | if size(A,2) != size(B,1) |
|
0 commit comments