Patch summary: route more cuBLAS wrappers in lib/cublas/wrappers.jl to their `_64` entry points.
  * rotmg!, trmm! and trsm! call the `_64` symbol when CUBLAS.version() >= v"12.0" and the
    legacy symbol otherwise.
  * The Float16 gemv/gemm table rows (plain, batched, strided-batched) name `_64` symbols in
    their second column instead of repeating the legacy name.
  * The combined trmm/trsm generator loop is split in two so that each routine carries its own
    (legacy, _64) symbol pair.
Review fix: the Float64 rotmg row pairs cublasDrotmg_v2 with cublasDrotmg_v2_64; it previously
named the single-precision cublasSrotmg_v2_64.
NOTE(review): leading whitespace in the hunks below was reconstructed; confirm it against the
target file before applying. The post-image blob id on the index line predates the review fix.

diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl
index d87d5e9b63..07c584860c 100644
--- a/lib/cublas/wrappers.jl
+++ b/lib/cublas/wrappers.jl
@@ -400,8 +400,8 @@ for (fname, fname_64, elty) in ((:cublasSrotm_v2, :cublasSrotm_v2_64, :Float32),
 end
 
 ## rotmg
-for (fname, elty) in ((:cublasSrotmg_v2, :Float32),
-                      (:cublasDrotmg_v2, :Float64))
+for (fname, fname_64, elty) in ((:cublasSrotmg_v2, :cublasSrotmg_v2_64, :Float32),
+                                (:cublasDrotmg_v2, :cublasDrotmg_v2_64, :Float64))
     @eval begin
         function rotmg!(d1::$elty,
                         d2::$elty,
@@ -412,7 +412,11 @@ for (fname, elty) in ((:cublasSrotmg_v2, :Float32),
             ref_d2 = CuRef(d2)
             ref_x = CuRef(x)
             ref_y = CuRef(y)
-            $fname(handle(), ref_d1, ref_d2, ref_x, ref_y, param)
+            if CUBLAS.version() >= v"12.0"
+                $fname_64(handle(), ref_d1, ref_d2, ref_x, ref_y, param)
+            else
+                $fname(handle(), ref_d1, ref_d2, ref_x, ref_y, param)
+            end
             ref_d1[], ref_d2[], ref_x[], ref_y[], param
         end
     end
@@ -545,8 +549,8 @@ end
 for (fname, fname_64, eltyin, eltyout) in (
         (:cublasDgemvBatched, :cublasDgemvBatched_64, :Float64, :Float64),
         (:cublasSgemvBatched, :cublasSgemvBatched_64, :Float32, :Float32),
-        (:cublasHSHgemvBatched, :cublasHSHgemvBatched, :Float16, :Float16),
-        (:cublasHSSgemvBatched, :cublasHSSgemvBatched, :Float16, :Float32),
+        (:cublasHSHgemvBatched, :cublasHSHgemvBatched_64, :Float16, :Float16),
+        (:cublasHSSgemvBatched, :cublasHSSgemvBatched_64, :Float16, :Float32),
         (:cublasZgemvBatched, :cublasZgemvBatched_64, :ComplexF64, :ComplexF64),
         (:cublasCgemvBatched, :cublasCgemvBatched_64, :ComplexF32, :ComplexF32),
     )
@@ -594,8 +598,8 @@ end
 for (fname, fname_64, eltyin, eltyout) in (
         (:cublasDgemvStridedBatched, :cublasDgemvStridedBatched_64, :Float64, :Float64),
         (:cublasSgemvStridedBatched, :cublasSgemvStridedBatched_64, :Float32, :Float32),
-        (:cublasHSHgemvStridedBatched, :cublasHSHgemvStridedBatched, :Float16, :Float16),
-        (:cublasHSSgemvStridedBatched, :cublasHSSgemvStridedBatched, :Float16, :Float32),
+        (:cublasHSHgemvStridedBatched, :cublasHSHgemvStridedBatched_64, :Float16, :Float16),
+        (:cublasHSSgemvStridedBatched, :cublasHSSgemvStridedBatched_64, :Float16, :Float32),
         (:cublasZgemvStridedBatched, :cublasZgemvStridedBatched_64, :ComplexF64, :ComplexF64),
         (:cublasCgemvStridedBatched, :cublasCgemvStridedBatched_64, :ComplexF32, :ComplexF32),
     )
@@ -1116,7 +1120,7 @@ end
 ## (GE) general matrix-matrix multiplication
 for (fname, fname_64, elty) in ((:cublasDgemm_v2, :cublasDgemm_v2_64, :Float64),
                                 (:cublasSgemm_v2, :cublasSgemm_v2_64, :Float32),
-                                (:cublasHgemm, :cublasHgemm, :Float16),
+                                (:cublasHgemm, :cublasHgemm_64, :Float16),
                                 (:cublasZgemm_v2, :cublasZgemm_v2_64, :ComplexF64),
                                 (:cublasCgemm_v2, :cublasCgemm_v2_64, :ComplexF32))
     @eval begin
@@ -1527,7 +1531,7 @@ end
 ## (GE) general matrix-matrix multiplication batched
 for (fname, fname_64, elty) in ((:cublasDgemmBatched, :cublasDgemmBatched_64, :Float64),
                                 (:cublasSgemmBatched, :cublasSgemmBatched_64, :Float32),
-                                (:cublasHgemmBatched, :cublasHgemmBatched, :Float16),
+                                (:cublasHgemmBatched, :cublasHgemmBatched_64, :Float16),
                                 (:cublasZgemmBatched, :cublasZgemmBatched_64, :ComplexF64),
                                 (:cublasCgemmBatched, :cublasCgemmBatched_64, :ComplexF32))
     @eval begin
@@ -1594,7 +1598,7 @@ end
 ## (GE) general matrix-matrix multiplication strided batched
 for (fname, fname_64, elty) in ((:cublasDgemmStridedBatched, :cublasDgemmStridedBatched_64, :Float64),
                                 (:cublasSgemmStridedBatched, :cublasSgemmStridedBatched_64, :Float32),
-                                (:cublasHgemmStridedBatched, :cublasHgemmStridedBatched, :Float16),
+                                (:cublasHgemmStridedBatched, :cublasHgemmStridedBatched_64, :Float16),
                                 (:cublasZgemmStridedBatched, :cublasZgemmStridedBatched_64, :ComplexF64),
                                 (:cublasCgemmStridedBatched, :cublasCgemmStridedBatched_64, :ComplexF32))
     @eval begin
@@ -1945,11 +1949,11 @@ function her2k(uplo::Char, trans::Char,
 end
 
 ## (TR) Triangular matrix and vector multiplication and solution
-for (mmname, smname, elty) in
-        ((:cublasDtrmm_v2,:cublasDtrsm_v2,:Float64),
-         (:cublasStrmm_v2,:cublasStrsm_v2,:Float32),
-         (:cublasZtrmm_v2,:cublasZtrsm_v2,:ComplexF64),
-         (:cublasCtrmm_v2,:cublasCtrsm_v2,:ComplexF32))
+for (mmname, mmname_64, elty) in
+        ((:cublasDtrmm_v2, :cublasDtrmm_v2_64, :Float64),
+         (:cublasStrmm_v2, :cublasStrmm_v2_64, :Float32),
+         (:cublasZtrmm_v2, :cublasZtrmm_v2_64, :ComplexF64),
+         (:cublasCtrmm_v2, :cublasCtrmm_v2_64, :ComplexF32))
     @eval begin
         # Note: CUBLAS differs from BLAS API for trmm
         #   BLAS: inplace modification of B
@@ -1972,10 +1976,22 @@ for (mmname, smname, elty) in
             lda = max(1,stride(A,2))
             ldb = max(1,stride(B,2))
             ldc = max(1,stride(C,2))
-            $mmname(handle(), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, C, ldc)
+            if CUBLAS.version() >= v"12.0"
+                $mmname_64(handle(), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, C, ldc)
+            else
+                $mmname(handle(), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, C, ldc)
+            end
             C
         end
+    end
+end
 
+for (smname, smname_64, elty) in
+        ((:cublasDtrsm_v2, :cublasDtrsm_v2_64, :Float64),
+         (:cublasStrsm_v2, :cublasStrsm_v2_64, :Float32),
+         (:cublasZtrsm_v2, :cublasZtrsm_v2_64, :ComplexF64),
+         (:cublasCtrsm_v2, :cublasCtrsm_v2_64, :ComplexF32))
+    @eval begin
         function trsm!(side::Char,
                        uplo::Char,
                        transa::Char,
@@ -1990,7 +2006,11 @@ for (mmname, smname, elty) in
             if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trsm!")) end
             lda = max(1,stride(A,2))
             ldb = max(1,stride(B,2))
-            $smname(handle(), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb)
+            if CUBLAS.version() >= v"12.0"
+                $smname_64(handle(), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb)
+            else
+                $smname(handle(), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb)
+            end
             B
         end
     end