Skip to content

Commit 6ef625a

Browse files
authored
[CUBLAS] Update wrappers to use the ILP64 API (#2845)
* Update wrappers to use the ILP64 API * Update wrappers.jl * Add missing Julia wrappers * rotmg doesn't have ILP64 symbols
1 parent d6ad9c3 commit 6ef625a

File tree

3 files changed

+316
-14
lines changed

3 files changed

+316
-14
lines changed

lib/cublas/libcublas.jl

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5914,6 +5914,17 @@ end
59145914
incy::Cint, batchCount::Cint)::cublasStatus_t
59155915
end
59165916

5917+
# Binding for the libcublas symbol `cublasHSHgemvBatched_64`.
# Calls `initialize_context()` and then forwards every argument, unchanged and in
# order, to the C function through `@ccall`; the ccall yields a `cublasStatus_t`.
# Argument types declared for the call:
#   - m, n, lda, incx, incy, batchCount : Int64
#   - alpha, beta                        : CuRef{Cfloat}
#   - Aarray, xarray, yarray             : CuPtr{Ptr{Float16}}
# NOTE(review): `@checked` is defined outside this view; presumably it converts a
# non-success status into an error -- confirm against its definition.
@checked function cublasHSHgemvBatched_64(handle, trans, m, n, alpha, Aarray, lda, xarray,
5918+
incx, beta, yarray, incy, batchCount)
5919+
initialize_context()
5920+
@ccall libcublas.cublasHSHgemvBatched_64(handle::cublasHandle_t, trans::cublasOperation_t,
5921+
m::Int64, n::Int64, alpha::CuRef{Cfloat},
5922+
Aarray::CuPtr{Ptr{Float16}}, lda::Int64,
5923+
xarray::CuPtr{Ptr{Float16}}, incx::Int64,
5924+
beta::CuRef{Cfloat}, yarray::CuPtr{Ptr{Float16}},
5925+
incy::Int64, batchCount::Int64)::cublasStatus_t
5926+
end
5927+
59175928
@checked function cublasHSSgemvBatched(handle, trans, m, n, alpha, Aarray, lda, xarray,
59185929
incx, beta, yarray, incy, batchCount)
59195930
initialize_context()
@@ -5925,6 +5936,17 @@ end
59255936
incy::Cint, batchCount::Cint)::cublasStatus_t
59265937
end
59275938

5939+
# Binding for the libcublas symbol `cublasHSSgemvBatched_64`.
# Calls `initialize_context()` and then forwards every argument, unchanged and in
# order, to the C function through `@ccall`; the ccall yields a `cublasStatus_t`.
# Argument types declared for the call:
#   - m, n, lda, incx, incy, batchCount : Int64
#   - alpha, beta                        : CuRef{Cfloat}
#   - Aarray, xarray                     : CuPtr{Ptr{Float16}}
#   - yarray                             : CuPtr{Ptr{Cfloat}} (output element type
#                                          differs from the Float16 inputs)
# NOTE(review): `@checked` is defined outside this view; presumably it converts a
# non-success status into an error -- confirm against its definition.
@checked function cublasHSSgemvBatched_64(handle, trans, m, n, alpha, Aarray, lda, xarray,
5940+
incx, beta, yarray, incy, batchCount)
5941+
initialize_context()
5942+
@ccall libcublas.cublasHSSgemvBatched_64(handle::cublasHandle_t, trans::cublasOperation_t,
5943+
m::Int64, n::Int64, alpha::CuRef{Cfloat},
5944+
Aarray::CuPtr{Ptr{Float16}}, lda::Int64,
5945+
xarray::CuPtr{Ptr{Float16}}, incx::Int64,
5946+
beta::CuRef{Cfloat}, yarray::CuPtr{Ptr{Cfloat}},
5947+
incy::Int64, batchCount::Int64)::cublasStatus_t
5948+
end
5949+
59285950
@checked function cublasTSTgemvBatched(handle, trans, m, n, alpha, Aarray, lda, xarray,
59295951
incx, beta, yarray, incy, batchCount)
59305952
initialize_context()
@@ -5936,6 +5958,17 @@ end
59365958
incy::Cint, batchCount::Cint)::cublasStatus_t
59375959
end
59385960

5961+
# Binding for the libcublas symbol `cublasTSTgemvBatched_64`.
# Calls `initialize_context()` and then forwards every argument, unchanged and in
# order, to the C function through `@ccall`; the ccall yields a `cublasStatus_t`.
# Argument types declared for the call:
#   - m, n, lda, incx, incy, batchCount : Int64
#   - alpha, beta                        : Ptr{Cfloat}
#   - Aarray, xarray, yarray             : Ptr{Ptr{BFloat16}}
# NOTE(review): the pointer arguments here are plain `Ptr`, whereas the
# `cublasHSHgemvBatched_64` / `cublasHSSgemvBatched_64` bindings in this file use
# `CuRef` / `CuPtr` for the same positions -- confirm whether that is intentional.
# NOTE(review): `@checked` is defined outside this view; presumably it converts a
# non-success status into an error -- confirm against its definition.
@checked function cublasTSTgemvBatched_64(handle, trans, m, n, alpha, Aarray, lda, xarray,
5962+
incx, beta, yarray, incy, batchCount)
5963+
initialize_context()
5964+
@ccall libcublas.cublasTSTgemvBatched_64(handle::cublasHandle_t, trans::cublasOperation_t,
5965+
m::Int64, n::Int64, alpha::Ptr{Cfloat},
5966+
Aarray::Ptr{Ptr{BFloat16}}, lda::Int64,
5967+
xarray::Ptr{Ptr{BFloat16}}, incx::Int64,
5968+
beta::Ptr{Cfloat}, yarray::Ptr{Ptr{BFloat16}},
5969+
incy::Int64, batchCount::Int64)::cublasStatus_t
5970+
end
5971+
59395972
@checked function cublasTSSgemvBatched(handle, trans, m, n, alpha, Aarray, lda, xarray,
59405973
incx, beta, yarray, incy, batchCount)
59415974
initialize_context()
@@ -5947,6 +5980,17 @@ end
59475980
incy::Cint, batchCount::Cint)::cublasStatus_t
59485981
end
59495982

5983+
# Binding for the libcublas symbol `cublasTSSgemvBatched_64`.
# Calls `initialize_context()` and then forwards every argument, unchanged and in
# order, to the C function through `@ccall`; the ccall yields a `cublasStatus_t`.
# Argument types declared for the call:
#   - m, n, lda, incx, incy, batchCount : Int64
#   - alpha, beta                        : Ptr{Cfloat}
#   - Aarray, xarray                     : Ptr{Ptr{BFloat16}}
#   - yarray                             : Ptr{Ptr{Cfloat}} (output element type
#                                          differs from the BFloat16 inputs)
# NOTE(review): the pointer arguments here are plain `Ptr`, whereas the
# `cublasHSHgemvBatched_64` / `cublasHSSgemvBatched_64` bindings in this file use
# `CuRef` / `CuPtr` for the same positions -- confirm whether that is intentional.
# NOTE(review): `@checked` is defined outside this view; presumably it converts a
# non-success status into an error -- confirm against its definition.
@checked function cublasTSSgemvBatched_64(handle, trans, m, n, alpha, Aarray, lda, xarray,
5984+
incx, beta, yarray, incy, batchCount)
5985+
initialize_context()
5986+
@ccall libcublas.cublasTSSgemvBatched_64(handle::cublasHandle_t, trans::cublasOperation_t,
5987+
m::Int64, n::Int64, alpha::Ptr{Cfloat},
5988+
Aarray::Ptr{Ptr{BFloat16}}, lda::Int64,
5989+
xarray::Ptr{Ptr{BFloat16}}, incx::Int64,
5990+
beta::Ptr{Cfloat}, yarray::Ptr{Ptr{Cfloat}},
5991+
incy::Int64, batchCount::Int64)::cublasStatus_t
5992+
end
5993+
59505994
@checked function cublasHSHgemvStridedBatched(handle, trans, m, n, alpha, A, lda, strideA,
59515995
x, incx, stridex, beta, y, incy, stridey,
59525996
batchCount)
@@ -5962,6 +6006,21 @@ end
59626006
batchCount::Cint)::cublasStatus_t
59636007
end
59646008

6009+
# Binding for the libcublas symbol `cublasHSHgemvStridedBatched_64`.
# Calls `initialize_context()` and then forwards every argument, unchanged and in
# order, to the C function through `@ccall`; the ccall yields a `cublasStatus_t`.
# Argument types declared for the call:
#   - m, n, lda, incx, incy, batchCount : Int64
#   - strideA, stridex, stridey          : Clonglong
#   - alpha, beta                        : CuRef{Cfloat}
#   - A, x, y                            : CuPtr{Float16}
# NOTE(review): `@checked` is defined outside this view; presumably it converts a
# non-success status into an error -- confirm against its definition.
@checked function cublasHSHgemvStridedBatched_64(handle, trans, m, n, alpha, A, lda, strideA,
6010+
x, incx, stridex, beta, y, incy, stridey,
6011+
batchCount)
6012+
initialize_context()
6013+
@ccall libcublas.cublasHSHgemvStridedBatched_64(handle::cublasHandle_t,
6014+
trans::cublasOperation_t, m::Int64, n::Int64,
6015+
alpha::CuRef{Cfloat}, A::CuPtr{Float16},
6016+
lda::Int64, strideA::Clonglong,
6017+
x::CuPtr{Float16}, incx::Int64,
6018+
stridex::Clonglong, beta::CuRef{Cfloat},
6019+
y::CuPtr{Float16}, incy::Int64,
6020+
stridey::Clonglong,
6021+
batchCount::Int64)::cublasStatus_t
6022+
end
6023+
59656024
@checked function cublasHSSgemvStridedBatched(handle, trans, m, n, alpha, A, lda, strideA,
59666025
x, incx, stridex, beta, y, incy, stridey,
59676026
batchCount)
@@ -5977,6 +6036,21 @@ end
59776036
batchCount::Cint)::cublasStatus_t
59786037
end
59796038

6039+
# Binding for the libcublas symbol `cublasHSSgemvStridedBatched_64`.
# Calls `initialize_context()` and then forwards every argument, unchanged and in
# order, to the C function through `@ccall`; the ccall yields a `cublasStatus_t`.
# Argument types declared for the call:
#   - m, n, lda, incx, incy, batchCount : Int64
#   - strideA, stridex, stridey          : Clonglong
#   - alpha, beta                        : CuRef{Cfloat}
#   - A, x                               : CuPtr{Float16}
#   - y                                  : CuPtr{Cfloat} (output element type
#                                          differs from the Float16 inputs)
# NOTE(review): `@checked` is defined outside this view; presumably it converts a
# non-success status into an error -- confirm against its definition.
@checked function cublasHSSgemvStridedBatched_64(handle, trans, m, n, alpha, A, lda, strideA,
6040+
x, incx, stridex, beta, y, incy, stridey,
6041+
batchCount)
6042+
initialize_context()
6043+
@ccall libcublas.cublasHSSgemvStridedBatched_64(handle::cublasHandle_t,
6044+
trans::cublasOperation_t, m::Int64, n::Int64,
6045+
alpha::CuRef{Cfloat}, A::CuPtr{Float16},
6046+
lda::Int64, strideA::Clonglong,
6047+
x::CuPtr{Float16}, incx::Int64,
6048+
stridex::Clonglong, beta::CuRef{Cfloat},
6049+
y::CuPtr{Cfloat}, incy::Int64,
6050+
stridey::Clonglong,
6051+
batchCount::Int64)::cublasStatus_t
6052+
end
6053+
59806054
@checked function cublasTSTgemvStridedBatched(handle, trans, m, n, alpha, A, lda, strideA,
59816055
x, incx, stridex, beta, y, incy, stridey,
59826056
batchCount)
@@ -5992,6 +6066,21 @@ end
59926066
batchCount::Cint)::cublasStatus_t
59936067
end
59946068

6069+
# Binding for the libcublas symbol `cublasTSTgemvStridedBatched_64`.
# Calls `initialize_context()` and then forwards every argument, unchanged and in
# order, to the C function through `@ccall`; the ccall yields a `cublasStatus_t`.
# Argument types declared for the call:
#   - m, n, lda, incx, incy, batchCount : Int64
#   - strideA, stridex, stridey          : Clonglong
#   - alpha, beta                        : CuRef{Cfloat}
#   - A, x, y                            : CuPtr{BFloat16}
# NOTE(review): `@checked` is defined outside this view; presumably it converts a
# non-success status into an error -- confirm against its definition.
@checked function cublasTSTgemvStridedBatched_64(handle, trans, m, n, alpha, A, lda, strideA,
6070+
x, incx, stridex, beta, y, incy, stridey,
6071+
batchCount)
6072+
initialize_context()
6073+
@ccall libcublas.cublasTSTgemvStridedBatched_64(handle::cublasHandle_t,
6074+
trans::cublasOperation_t, m::Int64, n::Int64,
6075+
alpha::CuRef{Cfloat}, A::CuPtr{BFloat16},
6076+
lda::Int64, strideA::Clonglong,
6077+
x::CuPtr{BFloat16}, incx::Int64,
6078+
stridex::Clonglong, beta::CuRef{Cfloat},
6079+
y::CuPtr{BFloat16}, incy::Int64,
6080+
stridey::Clonglong,
6081+
batchCount::Int64)::cublasStatus_t
6082+
end
6083+
59956084
@checked function cublasTSSgemvStridedBatched(handle, trans, m, n, alpha, A, lda, strideA,
59966085
x, incx, stridex, beta, y, incy, stridey,
59976086
batchCount)
@@ -6007,6 +6096,21 @@ end
60076096
batchCount::Cint)::cublasStatus_t
60086097
end
60096098

6099+
# Binding for the libcublas symbol `cublasTSSgemvStridedBatched_64`.
# Calls `initialize_context()` and then forwards every argument, unchanged and in
# order, to the C function through `@ccall`; the ccall yields a `cublasStatus_t`.
# Argument types declared for the call:
#   - m, n, lda, incx, incy, batchCount : Int64
#   - strideA, stridex, stridey          : Clonglong
#   - alpha, beta                        : CuRef{Cfloat}
#   - A, x                               : CuPtr{BFloat16}
#   - y                                  : CuPtr{Cfloat} (output element type
#                                          differs from the BFloat16 inputs)
# NOTE(review): `@checked` is defined outside this view; presumably it converts a
# non-success status into an error -- confirm against its definition.
@checked function cublasTSSgemvStridedBatched_64(handle, trans, m, n, alpha, A, lda, strideA,
6100+
x, incx, stridex, beta, y, incy, stridey,
6101+
batchCount)
6102+
initialize_context()
6103+
@ccall libcublas.cublasTSSgemvStridedBatched_64(handle::cublasHandle_t,
6104+
trans::cublasOperation_t, m::Int64, n::Int64,
6105+
alpha::CuRef{Cfloat}, A::CuPtr{BFloat16},
6106+
lda::Int64, strideA::Clonglong,
6107+
x::CuPtr{BFloat16}, incx::Int64,
6108+
stridex::Clonglong, beta::CuRef{Cfloat},
6109+
y::CuPtr{Cfloat}, incy::Int64,
6110+
stridey::Clonglong,
6111+
batchCount::Int64)::cublasStatus_t
6112+
end
6113+
60106114
@checked function cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta,
60116115
C, ldc)
60126116
initialize_context()
@@ -6017,6 +6121,16 @@ end
60176121
C::Ptr{Float16}, ldc::Cint)::cublasStatus_t
60186122
end
60196123

6124+
# Binding for the libcublas symbol `cublasHgemm_64`.
# Calls `initialize_context()` and then forwards every argument, unchanged and in
# order, to the C function through `@ccall`; the ccall yields a `cublasStatus_t`.
# Argument types declared for the call:
#   - m, n, k, lda, ldb, ldc : Int64
#   - alpha, beta            : Ptr{Float16}
#   - A, B, C                : Ptr{Float16}
# NOTE(review): all pointer arguments are plain `Ptr{Float16}`, whereas
# `cublasHgemmBatched_64` and `cublasHgemmStridedBatched_64` in this file use
# `CuRef` / `CuPtr` -- confirm whether that is intentional.
# NOTE(review): `@checked` is defined outside this view; presumably it converts a
# non-success status into an error -- confirm against its definition.
@checked function cublasHgemm_64(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta,
6125+
C, ldc)
6126+
initialize_context()
6127+
@ccall libcublas.cublasHgemm_64(handle::cublasHandle_t, transa::cublasOperation_t,
6128+
transb::cublasOperation_t, m::Int64, n::Int64, k::Int64,
6129+
alpha::Ptr{Float16}, A::Ptr{Float16}, lda::Int64,
6130+
B::Ptr{Float16}, ldb::Int64, beta::Ptr{Float16},
6131+
C::Ptr{Float16}, ldc::Int64)::cublasStatus_t
6132+
end
6133+
60206134
@checked function cublasHgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda,
60216135
Barray, ldb, beta, Carray, ldc, batchCount)
60226136
initialize_context()
@@ -6030,6 +6144,19 @@ end
60306144
batchCount::Cint)::cublasStatus_t
60316145
end
60326146

6147+
# Binding for the libcublas symbol `cublasHgemmBatched_64`.
# Calls `initialize_context()` and then forwards every argument, unchanged and in
# order, to the C function through `@ccall`; the ccall yields a `cublasStatus_t`.
# Argument types declared for the call:
#   - m, n, k, lda, ldb, ldc, batchCount : Int64
#   - alpha, beta                         : CuRef{Float16}
#   - Aarray, Barray, Carray              : CuPtr{Ptr{Float16}}
# NOTE(review): `@checked` is defined outside this view; presumably it converts a
# non-success status into an error -- confirm against its definition.
@checked function cublasHgemmBatched_64(handle, transa, transb, m, n, k, alpha, Aarray, lda,
6148+
Barray, ldb, beta, Carray, ldc, batchCount)
6149+
initialize_context()
6150+
@ccall libcublas.cublasHgemmBatched_64(handle::cublasHandle_t, transa::cublasOperation_t,
6151+
transb::cublasOperation_t, m::Int64, n::Int64,
6152+
k::Int64, alpha::CuRef{Float16},
6153+
Aarray::CuPtr{Ptr{Float16}}, lda::Int64,
6154+
Barray::CuPtr{Ptr{Float16}}, ldb::Int64,
6155+
beta::CuRef{Float16},
6156+
Carray::CuPtr{Ptr{Float16}}, ldc::Int64,
6157+
batchCount::Int64)::cublasStatus_t
6158+
end
6159+
60336160
@checked function cublasHgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda,
60346161
strideA, B, ldb, strideB, beta, C, ldc, strideC,
60356162
batchCount)
@@ -6045,3 +6172,19 @@ end
60456172
ldc::Cint, strideC::Clonglong,
60466173
batchCount::Cint)::cublasStatus_t
60476174
end
6175+
6176+
# Binding for the libcublas symbol `cublasHgemmStridedBatched_64`.
# Calls `initialize_context()` and then forwards every argument, unchanged and in
# order, to the C function through `@ccall`; the ccall yields a `cublasStatus_t`.
# Argument types declared for the call:
#   - m, n, k, lda, ldb, ldc, batchCount : Int64
#   - strideA, strideB, strideC           : Clonglong
#   - alpha, beta                         : CuRef{Float16}
#   - A, B, C                             : CuPtr{Float16}
# NOTE(review): `@checked` is defined outside this view; presumably it converts a
# non-success status into an error -- confirm against its definition.
@checked function cublasHgemmStridedBatched_64(handle, transa, transb, m, n, k, alpha, A, lda,
6177+
strideA, B, ldb, strideB, beta, C, ldc, strideC,
6178+
batchCount)
6179+
initialize_context()
6180+
@ccall libcublas.cublasHgemmStridedBatched_64(handle::cublasHandle_t,
6181+
transa::cublasOperation_t,
6182+
transb::cublasOperation_t, m::Int64, n::Int64,
6183+
k::Int64, alpha::CuRef{Float16},
6184+
A::CuPtr{Float16}, lda::Int64,
6185+
strideA::Clonglong, B::CuPtr{Float16},
6186+
ldb::Int64, strideB::Clonglong,
6187+
beta::CuRef{Float16}, C::CuPtr{Float16},
6188+
ldc::Int64, strideC::Clonglong,
6189+
batchCount::Int64)::cublasStatus_t
6190+
end

lib/cublas/wrappers.jl

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -545,8 +545,8 @@ end
545545
for (fname, fname_64, eltyin, eltyout) in (
546546
(:cublasDgemvBatched, :cublasDgemvBatched_64, :Float64, :Float64),
547547
(:cublasSgemvBatched, :cublasSgemvBatched_64, :Float32, :Float32),
548-
(:cublasHSHgemvBatched, :cublasHSHgemvBatched, :Float16, :Float16),
549-
(:cublasHSSgemvBatched, :cublasHSSgemvBatched, :Float16, :Float32),
548+
(:cublasHSHgemvBatched, :cublasHSHgemvBatched_64, :Float16, :Float16),
549+
(:cublasHSSgemvBatched, :cublasHSSgemvBatched_64, :Float16, :Float32),
550550
(:cublasZgemvBatched, :cublasZgemvBatched_64, :ComplexF64, :ComplexF64),
551551
(:cublasCgemvBatched, :cublasCgemvBatched_64, :ComplexF32, :ComplexF32),
552552
)
@@ -594,8 +594,8 @@ end
594594
for (fname, fname_64, eltyin, eltyout) in (
595595
(:cublasDgemvStridedBatched, :cublasDgemvStridedBatched_64, :Float64, :Float64),
596596
(:cublasSgemvStridedBatched, :cublasSgemvStridedBatched_64, :Float32, :Float32),
597-
(:cublasHSHgemvStridedBatched, :cublasHSHgemvStridedBatched, :Float16, :Float16),
598-
(:cublasHSSgemvStridedBatched, :cublasHSSgemvStridedBatched, :Float16, :Float32),
597+
(:cublasHSHgemvStridedBatched, :cublasHSHgemvStridedBatched_64, :Float16, :Float16),
598+
(:cublasHSSgemvStridedBatched, :cublasHSSgemvStridedBatched_64, :Float16, :Float32),
599599
(:cublasZgemvStridedBatched, :cublasZgemvStridedBatched_64, :ComplexF64, :ComplexF64),
600600
(:cublasCgemvStridedBatched, :cublasCgemvStridedBatched_64, :ComplexF32, :ComplexF32),
601601
)
@@ -1116,7 +1116,7 @@ end
11161116
## (GE) general matrix-matrix multiplication
11171117
for (fname, fname_64, elty) in ((:cublasDgemm_v2, :cublasDgemm_v2_64, :Float64),
11181118
(:cublasSgemm_v2, :cublasSgemm_v2_64, :Float32),
1119-
(:cublasHgemm, :cublasHgemm, :Float16),
1119+
(:cublasHgemm, :cublasHgemm_64, :Float16),
11201120
(:cublasZgemm_v2, :cublasZgemm_v2_64, :ComplexF64),
11211121
(:cublasCgemm_v2, :cublasCgemm_v2_64, :ComplexF32))
11221122
@eval begin
@@ -1527,7 +1527,7 @@ end
15271527
## (GE) general matrix-matrix multiplication batched
15281528
for (fname, fname_64, elty) in ((:cublasDgemmBatched, :cublasDgemmBatched_64, :Float64),
15291529
(:cublasSgemmBatched, :cublasSgemmBatched_64, :Float32),
1530-
(:cublasHgemmBatched, :cublasHgemmBatched, :Float16),
1530+
(:cublasHgemmBatched, :cublasHgemmBatched_64, :Float16),
15311531
(:cublasZgemmBatched, :cublasZgemmBatched_64, :ComplexF64),
15321532
(:cublasCgemmBatched, :cublasCgemmBatched_64, :ComplexF32))
15331533
@eval begin
@@ -1594,7 +1594,7 @@ end
15941594
## (GE) general matrix-matrix multiplication strided batched
15951595
for (fname, fname_64, elty) in ((:cublasDgemmStridedBatched, :cublasDgemmStridedBatched_64, :Float64),
15961596
(:cublasSgemmStridedBatched, :cublasSgemmStridedBatched_64, :Float32),
1597-
(:cublasHgemmStridedBatched, :cublasHgemmStridedBatched, :Float16),
1597+
(:cublasHgemmStridedBatched, :cublasHgemmStridedBatched_64, :Float16),
15981598
(:cublasZgemmStridedBatched, :cublasZgemmStridedBatched_64, :ComplexF64),
15991599
(:cublasCgemmStridedBatched, :cublasCgemmStridedBatched_64, :ComplexF32))
16001600
@eval begin
@@ -1945,11 +1945,11 @@ function her2k(uplo::Char, trans::Char,
19451945
end
19461946

19471947
## (TR) Triangular matrix and vector multiplication and solution
1948-
for (mmname, smname, elty) in
1949-
((:cublasDtrmm_v2,:cublasDtrsm_v2,:Float64),
1950-
(:cublasStrmm_v2,:cublasStrsm_v2,:Float32),
1951-
(:cublasZtrmm_v2,:cublasZtrsm_v2,:ComplexF64),
1952-
(:cublasCtrmm_v2,:cublasCtrsm_v2,:ComplexF32))
1948+
for (mmname, mmname_64, elty) in
1949+
((:cublasDtrmm_v2, :cublasDtrmm_v2_64, :Float64),
1950+
(:cublasStrmm_v2, :cublasStrmm_v2_64, :Float32),
1951+
(:cublasZtrmm_v2, :cublasZtrmm_v2_64, :ComplexF64),
1952+
(:cublasCtrmm_v2, :cublasCtrmm_v2_64, :ComplexF32))
19531953
@eval begin
19541954
# Note: CUBLAS differs from BLAS API for trmm
19551955
# BLAS: inplace modification of B
@@ -1972,10 +1972,22 @@ for (mmname, smname, elty) in
19721972
lda = max(1,stride(A,2))
19731973
ldb = max(1,stride(B,2))
19741974
ldc = max(1,stride(C,2))
1975-
$mmname(handle(), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, C, ldc)
1975+
if CUBLAS.version() >= v"12.0"
1976+
$mmname_64(handle(), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, C, ldc)
1977+
else
1978+
$mmname(handle(), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, C, ldc)
1979+
end
19761980
C
19771981
end
1982+
end
1983+
end
19781984

1985+
for (smname, smname_64, elty) in
1986+
((:cublasDtrsm_v2, :cublasDtrsm_v2_64, :Float64),
1987+
(:cublasStrsm_v2, :cublasStrsm_v2_64, :Float32),
1988+
(:cublasZtrsm_v2, :cublasZtrsm_v2_64, :ComplexF64),
1989+
(:cublasCtrsm_v2, :cublasCtrsm_v2_64, :ComplexF32))
1990+
@eval begin
19791991
function trsm!(side::Char,
19801992
uplo::Char,
19811993
transa::Char,
@@ -1990,7 +2002,11 @@ for (mmname, smname, elty) in
19902002
if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trsm!")) end
19912003
lda = max(1,stride(A,2))
19922004
ldb = max(1,stride(B,2))
1993-
$smname(handle(), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb)
2005+
if CUBLAS.version() >= v"12.0"
2006+
$smname_64(handle(), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb)
2007+
else
2008+
$smname(handle(), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb)
2009+
end
19942010
B
19952011
end
19962012
end

0 commit comments

Comments
 (0)