Commit 236643b: More tests for CUBLAS and a bugfix (#2659)
1 parent 6653404

7 files changed: +186 -78 lines changed

lib/cublas/wrappers.jl

Lines changed: 24 additions & 55 deletions
@@ -70,8 +70,6 @@ function juliaStorageType(T::Type{<:Complex}, ct::cublasComputeType_t)
         return Complex{Float32}
     elseif ct == CUBLAS_COMPUTE_64F || ct == CUBLAS_COMPUTE_64F_PEDANTIC
         return Complex{Float64}
-    elseif ct == CUBLAS_COMPUTE_32I || ct == CUBLAS_COMPUTE_32I_PEDANTIC
-        return Complex{Int32}
     else
         throw(ArgumentError("Julia type equivalent for compute type $ct does not exist!"))
     end
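
This removal is the bugfix from the commit title: cuBLAS has no complex integer compute path, so requesting `CUBLAS_COMPUTE_32I` for a complex storage type now falls through to the `ArgumentError` branch instead of returning the unsupported `Complex{Int32}`. A minimal sketch of the resulting behavior, assuming access to the internal helper and enum values shown above:

using CUDA
using CUDA.CUBLAS: juliaStorageType, CUBLAS_COMPUTE_32F, CUBLAS_COMPUTE_32I

# real storage types are unaffected by the fix
@assert juliaStorageType(Float32, CUBLAS_COMPUTE_32F) === Float32

# a complex type with integer compute is now rejected up front
try
    juliaStorageType(ComplexF32, CUBLAS_COMPUTE_32I)
catch err
    @assert err isa ArgumentError
end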
@@ -1174,14 +1172,10 @@ function gemmExComputeType(TA, TB, TC, m, k, n)
 
     # gemmEx requires sm_50 or higher
     cap = capability(device())
-    if cap < v"5"
-        return nothing
-    end
+    cap < v"5" && return nothing
 
     # source: CUBLAS Features and Technical Specifications
-    if Float16 in sig && cap < v"5.3"
-        return nothing
-    end
+    Float16 in sig && cap < v"5.3" && return nothing
 
     math_mode = CUDA.math_mode()
     reduced_precision = CUDA.math_precision()
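
Both rewrites use Julia's short-circuit `&&` as a one-line conditional: `cond && return x` evaluates the `return` only when `cond` is true, which is the idiomatic replacement for a three-line `if` block. A standalone sketch of the idiom:

# `&&` only evaluates its right-hand side when the left is true,
# so it doubles as a compact guard clause
function first_positive(xs)
    isempty(xs) && return nothing    # early bail-out
    for x in xs
        x > 0 && return x            # guard inside a loop
    end
    return nothing
end

@assert first_positive([-2, 0, 3]) == 3
@assert first_positive(Int[]) === nothing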
@@ -1192,15 +1186,10 @@ function gemmExComputeType(TA, TB, TC, m, k, n)
     end
 
     if sig === (Int8, Int32)
-        # starting with CUDA 11.2, this is unsupported (NVIDIA bug #3221266)
-        # TODO: might be fixed in a later version?
-        version() >= v"11.3.1" && return nothing
-
         # Int32=Int8*Int8 requires m,n,k to be multiples of 4
         # https://forums.developer.nvidia.com/t/cublasgemmex-cant-use-cuda-r-8i-compute-type-on-gtx1080/58100/2
-        if m%4 == 0 && n%4 == 0 && k%4 == 0
-            return math_mode==CUDA.PEDANTIC_MATH ? CUBLAS_COMPUTE_32I_PEDANTIC : CUBLAS_COMPUTE_32I
-        end
+        all_mod_4 = (m%4 == 0 && n%4 == 0 && k%4 == 0)
+        all_mod_4 && return math_mode==CUDA.PEDANTIC_MATH ? CUBLAS_COMPUTE_32I_PEDANTIC : CUBLAS_COMPUTE_32I
     end
 
     if math_mode == CUDA.FAST_MATH
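
Dropping the `version() >= v"11.3.1"` bail-out re-enables Int8 inputs with Int32 accumulation whenever m, n, and k are all multiples of 4. A usage sketch of that path through the high-level wrapper (sizes are illustrative; assumes an sm_50+ device and the `gemmEx!` method shown further down in this diff):

using CUDA, CUDA.CUBLAS

m, k, n = 8, 12, 4                     # every dimension a multiple of 4
A = CuArray(rand(Int8, m, k))
B = CuArray(rand(Int8, k, n))
C = CUDA.zeros(Int32, m, n)

CUBLAS.gemmEx!('N', 'N', Int32(1), A, B, Int32(0), C)
@assert Array(C) == Int32.(Array(A)) * Int32.(Array(B))   # exact: integer accumulation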
@@ -1231,13 +1220,8 @@ function gemmExComputeType(TA, TB, TC, m, k, n)
            sig === (Complex{Float64}, Complex{Float64})
         return math_mode==CUDA.PEDANTIC_MATH ? CUBLAS_COMPUTE_64F_PEDANTIC : CUBLAS_COMPUTE_64F
     end
-
-    # BFloat16 support was added in CUDA 11
-    if version() >= v"11"
-        if sig === (BFloat16, BFloat16) ||
-           sig === (BFloat16, Float32)
-            return math_mode==CUDA.PEDANTIC_MATH ? CUBLAS_COMPUTE_32F_PEDANTIC : CUBLAS_COMPUTE_32F
-        end
+    if sig === (BFloat16, BFloat16) || sig === (BFloat16, Float32)
+        return math_mode==CUDA.PEDANTIC_MATH ? CUBLAS_COMPUTE_32F_PEDANTIC : CUBLAS_COMPUTE_32F
     end
 
     return nothing
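
This cleanup reflects that CUDA.jl now assumes CUDA 11 or newer (the removed `else` branches below errored out for older toolkits), so the `version() >= v"11"` check around the BFloat16 signatures was always true. A sketch of the `(BFloat16, Float32)` signature, assuming the BFloat16s.jl package that provides the storage type:

using CUDA, CUDA.CUBLAS
using BFloat16s: BFloat16

A = CuArray(BFloat16.(rand(Float32, 4, 4)))
B = CuArray(BFloat16.(rand(Float32, 4, 4)))
C = CUDA.zeros(Float32, 4, 4)    # BFloat16 inputs, Float32 output: 32F compute

CUBLAS.gemmEx!('N', 'N', 1f0, A, B, 0f0, C)
@assert isapprox(Array(C), Float32.(Array(A)) * Float32.(Array(B)); rtol=1e-2)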
@@ -1263,20 +1247,11 @@ function gemmEx!(transA::Char, transB::Char,
     isnothing(computeType) &&
         throw(ArgumentError("gemmEx does not support $(eltype(C))=$(eltype(A))*$(eltype(B))"))
     computeT = juliaStorageType(eltype(C), computeType)
-    if version() >= v"11.0"
-        # with CUDA 11, the compute type encodes the math mode.
-        cublasGemmEx(
-            handle(), transA, transB, m, n, k, CuRef{computeT}(alpha), A, eltype(A), lda, B,
-            eltype(B), ldb, CuRef{computeT}(beta), C, eltype(C), ldc, computeType, algo
-        )
-    else
-        # before CUDA 11, it was a plain cudaDataType.
-        computeType = convert(cudaDataType, computeT)
-        cublasGemmEx_old(
-            handle(), transA, transB, m, n, k, CuRef{computeT}(alpha), A, eltype(A), lda, B,
-            eltype(B), ldb, CuRef{computeT}(beta), C, eltype(C), ldc, computeType, algo
-        )
-    end
+
+    cublasGemmEx(
+        handle(), transA, transB, m, n, k, CuRef{computeT}(alpha), A, eltype(A), lda, B,
+        eltype(B), ldb, CuRef{computeT}(beta), C, eltype(C), ldc, computeType, algo
+    )
     C
 end
 
@@ -1311,15 +1286,11 @@ function gemmBatchedEx!(transA::Char, transB::Char,
     Aptrs = unsafe_batch(A)
     Bptrs = unsafe_batch(B)
     Cptrs = unsafe_batch(C)
-    if version() >= v"11.0"
-        # with CUDA 11, the compute type encodes the math mode.
-        cublasGemmBatchedEx(
-            handle(), transA, transB, m, n, k, CuRef{computeT}(alpha), Aptrs, eltype(A[1]), lda, Bptrs,
-            eltype(B[1]), ldb, CuRef{computeT}(beta), Cptrs, eltype(C[1]), ldc, length(A), computeType, algo
-        )
-    else
-        error("Not implemented for CUDA 11 and below.")
-    end
+
+    cublasGemmBatchedEx(
+        handle(), transA, transB, m, n, k, CuRef{computeT}(alpha), Aptrs, eltype(A[1]), lda, Bptrs,
+        eltype(B[1]), ldb, CuRef{computeT}(beta), Cptrs, eltype(C[1]), ldc, length(A), computeType, algo
+    )
     unsafe_free!(Cptrs)
     unsafe_free!(Bptrs)
     unsafe_free!(Aptrs)
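
With the dead `else error(...)` branch gone, the batched wrapper calls the CUDA 11+ entry point unconditionally. A usage sketch of the high-level API (assumes `CUBLAS.gemmBatchedEx!` accepts vectors of equally sized matrices, as the `unsafe_batch` pointer packing above implies):

using CUDA, CUDA.CUBLAS

batch = 4
As = [CUDA.rand(Float32, 8, 8) for _ in 1:batch]
Bs = [CUDA.rand(Float32, 8, 8) for _ in 1:batch]
Cs = [CUDA.zeros(Float32, 8, 8) for _ in 1:batch]

CUBLAS.gemmBatchedEx!('N', 'N', 1f0, As, Bs, 0f0, Cs)   # one call, four GEMMs
for i in 1:batch
    @assert isapprox(Array(Cs[i]), Array(As[i]) * Array(Bs[i]); rtol=1e-4)
end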
@@ -1357,15 +1328,10 @@ function gemmStridedBatchedEx!(
     isnothing(computeType) &&
         throw(ArgumentError("gemmEx does not support $(eltype(C))=$(eltype(A))*$(eltype(B))"))
     computeT = juliaStorageType(eltype(C), computeType)
-    if version() >= v"11.0"
-        # with CUDA 11, the compute type encodes the math mode.
-        cublasGemmStridedBatchedEx(
-            handle(), transA, transB, m, n, k, CuRef{computeT}(alpha), A, eltype(A), lda, strideA,
-            B, eltype(B), ldb, strideB, CuRef{computeT}(beta), C, eltype(C), ldc, strideC,
-            batchCount, computeType, algo)
-    else
-        error("Not implemented for CUDA 11 and below.")
-    end
+    cublasGemmStridedBatchedEx(
+        handle(), transA, transB, m, n, k, CuRef{computeT}(alpha), A, eltype(A), lda, strideA,
+        B, eltype(B), ldb, strideB, CuRef{computeT}(beta), C, eltype(C), ldc, strideC,
+        batchCount, computeType, algo)
     C
 end
 
@@ -1382,6 +1348,8 @@ end
     #ptrs = [pointer(strided, (i-1)*batch_stride + 1) for i in 1:batch_size]
     # fill the array on the GPU to avoid synchronous copies and support larger batch sizes
     ptrs = CuArray{CuPtr{T}}(undef, batch_size)
+    # device-side code
+    ## COV_EXCL_START
     function compute_pointers()
         i = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
         grid_stride = gridDim().x * blockDim().x
@@ -1392,6 +1360,7 @@ end
         end
         return
     end
+    ## COV_EXCL_STOP
     kernel = @cuda launch = false compute_pointers()
     config = launch_configuration(kernel.fun)
     threads = min(config.threads, batch_size)
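
The `## COV_EXCL_START`/`## COV_EXCL_STOP` pair is the Coverage.jl convention for excluding lines from coverage reports: `compute_pointers` executes on the device, so host-side coverage would always count its body as unrun. The closure itself is a grid-stride kernel, where each thread starts at its global index and advances by the total thread count. A self-contained sketch of that pattern using only CUDA.jl's public kernel API:

using CUDA

function grid_stride_fill!(out, val)
    i = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
    stride = gridDim().x * blockDim().x
    while i <= length(out)
        @inbounds out[i] = val   # each thread handles indices i, i+stride, ...
        i += stride
    end
    return
end

out = CUDA.zeros(Float32, 10_000)
kernel = @cuda launch=false grid_stride_fill!(out, 1f0)
config = launch_configuration(kernel.fun)
threads = min(config.threads, length(out))
blocks = cld(length(out), threads)
kernel(out, 1f0; threads, blocks)
@assert all(Array(out) .== 1f0)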
@@ -2337,7 +2306,7 @@ for (fname, elty) in ((:cublasDgetriBatched, :Float64),
         ldc = max(1, stride(C[1], 2))
         Aptrs = unsafe_batch(A)
         Cptrs = unsafe_batch(C)
-        info = CuArrays.zeros(Cint, length(A))
+        info = CUDA.zeros(Cint, length(A))
         $fname(handle(), n, Aptrs, lda, pivotArray, Cptrs, ldc, info, length(A))
         unsafe_free!(Cptrs)
         unsafe_free!(Aptrs)
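
`CuArrays.zeros` is a holdover from before the CuArrays.jl package was folded into CUDA.jl, so this method could not have run since the merge; `CUDA.zeros(Cint, length(A))` now allocates the device-side status vector (one LAPACK-style info code per batch entry) that `cublas<t>getriBatched` fills in.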

test/libraries/cublas/extensions.jl

Lines changed: 113 additions & 23 deletions
@@ -54,6 +54,14 @@ k = 13
             end
             @test inv(P)*dL*dU ≈ inv(C.P) * C.L * C.U
         end
+        # generate bad matrices
+        A_bad = vcat([rand(elty,m,m) for i in 1:9], [rand(elty, m, m+1)])
+        # move to device
+        d_A_bad = CuArray{elty, 2}[]
+        for i in 1:length(A)
+            push!(d_A_bad,CuArray(A_bad[i]))
+        end
+        @test_throws DimensionMismatch CUBLAS.getrf_batched!(d_A_bad, true)
     end
 
     @testset "getrf_batched" begin
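
The new negative test builds a batch whose last matrix has one extra column, so the wrapper's uniform-size check raises `DimensionMismatch` before anything reaches cuBLAS. The same pattern in isolation (a sketch, assuming a working GPU):

using CUDA, CUDA.CUBLAS

good = [CUDA.rand(Float32, 4, 4) for _ in 1:3]
bad  = vcat(good, [CUDA.rand(Float32, 4, 5)])   # ragged batch: last entry is 4×5
try
    CUBLAS.getrf_batched!(bad, true)
catch err
    @assert err isa DimensionMismatch
end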
@@ -128,6 +136,11 @@ k = 13
             end
             @test inv(P)*dL*dU ≈ inv(C.P) * C.L * C.U
         end
+        # generate bad strided matrix
+        A = rand(elty,m,m+1,10)
+        # move to device
+        d_A = CuArray(A)
+        @test_throws DimensionMismatch CUBLAS.getrf_strided_batched!(d_A, true)
     end
 
     @testset "getrf_strided_batched" begin
@@ -168,24 +181,41 @@ k = 13
 for (opchar,opfun) in (('N',identity), ('T',transpose), ('C',adjoint))
 
     @testset "getrs_batched!" begin
-        A = [rand(elty,n,n) for _ in 1:k];
-        d_A = [CuArray(a) for a in A];
-        d_A2 = deepcopy(d_A);
-        d_pivot, info, d_LU = CUDA.CUBLAS.getrf_batched!(d_A, true);
+        A = [rand(elty,n,n) for _ in 1:k]
+        d_A = [CuArray(a) for a in A]
+        d_A2 = deepcopy(d_A)
+        d_pivot, info, d_LU = CUDA.CUBLAS.getrf_batched!(d_A, true)
         @test d_LU == d_A
-        d_pivot2 = similar(d_pivot);
-        info2 = similar(info);
-        CUDA.CUBLAS.getrf_batched!(d_A2, d_pivot2, info2);
+        d_pivot2 = similar(d_pivot)
+        info2 = similar(info)
+        CUDA.CUBLAS.getrf_batched!(d_A2, d_pivot2, info2)
         @test isapprox(d_pivot, d_pivot2)
         @test isapprox(info, info2)
-        B = [rand(elty,n,m) for _ in 1:k];
-        d_B = [CuArray(b) for b in B];
-        info2, d_Bhat = CUDA.CUBLAS.getrs_batched!(opchar, d_LU, d_B, d_pivot);
+        B = [rand(elty,n,m) for _ in 1:k]
+        d_B = [CuArray(b) for b in B]
+        info2, d_Bhat = CUDA.CUBLAS.getrs_batched!(opchar, d_LU, d_B, d_pivot)
         @test d_Bhat == d_B
-        h_Bhat = [collect(bh) for bh in d_Bhat];
+        h_Bhat = [collect(bh) for bh in d_Bhat]
         for i in 1:k
             @test h_Bhat[i] ≈ opfun(A[i]) \ B[i]
         end
+
+        # generate bad matrices
+        A_bad = vcat([rand(elty,m,m) for i in 1:9], [rand(elty, m, m+1)])
+        # move to device
+        d_A_bad = CuArray{elty, 2}[]
+        for i in 1:length(A_bad)
+            push!(d_A_bad,CuArray(A_bad[i]))
+        end
+        @test_throws DimensionMismatch CUBLAS.getrs_batched!(opchar, d_A_bad, d_B, d_pivot)
+        # generate bad matrices
+        A_bad = [rand(elty,m+1,m+1) for i in 1:10]
+        # move to device
+        d_A_bad = CuArray{elty, 2}[]
+        for i in 1:length(A_bad)
+            push!(d_A_bad,CuArray(A_bad[i]))
+        end
+        @test_throws DimensionMismatch CUBLAS.getrs_batched!(opchar, d_A_bad, d_B, d_pivot)
     end
 
     @testset "getrs_batched" begin
@@ -210,24 +240,31 @@ k = 13
     end
 
     @testset "getrs_strided_batched!" begin
-        A = rand(elty,n,n,k);
-        d_A = CuArray(A);
-        d_A2 = copy(d_A);
-        d_pivot, info, d_LU = CUDA.CUBLAS.getrf_strided_batched!(d_A, true);
+        A = rand(elty,n,n,k)
+        d_A = CuArray(A)
+        d_A2 = copy(d_A)
+        d_pivot, info, d_LU = CUDA.CUBLAS.getrf_strided_batched!(d_A, true)
         @test d_LU == d_A
-        d_pivot2 = similar(d_pivot);
-        info2 = similar(info);
-        CUDA.CUBLAS.getrf_strided_batched!(d_A2, d_pivot2, info2);
+        d_pivot2 = similar(d_pivot)
+        info2 = similar(info)
+        CUDA.CUBLAS.getrf_strided_batched!(d_A2, d_pivot2, info2)
         @test isapprox(d_pivot, d_pivot2)
         @test isapprox(info, info2)
-        B = rand(elty,n,m,k);
-        d_B = CuArray(B);
-        info2, d_Bhat = CUDA.CUBLAS.getrs_strided_batched!(opchar, d_LU, d_B, d_pivot);
+        B = rand(elty,n,m,k)
+        d_B = CuArray(B)
+        info2, d_Bhat = CUDA.CUBLAS.getrs_strided_batched!(opchar, d_LU, d_B, d_pivot)
         @test d_Bhat == d_B
-        h_Bhat = collect(d_Bhat);
+        h_Bhat = collect(d_Bhat)
         for i in 1:k
             @test h_Bhat[:,:,i] ≈ opfun(A[:,:,i]) \ B[:,:,i]
         end
+
+        A_bad = rand(elty,n+1,n,k)
+        d_A_bad = CuArray(A_bad)
+        @test_throws DimensionMismatch CUDA.CUBLAS.getrs_strided_batched!(opchar, d_A_bad, d_B, d_pivot)
+        A_bad = rand(elty,n+1,n+1,k)
+        d_A_bad = CuArray(A_bad)
+        @test_throws DimensionMismatch CUDA.CUBLAS.getrs_strided_batched!(opchar, d_A_bad, d_B, d_pivot)
     end
 
     @testset "getrs_strided_batched" begin
@@ -267,6 +304,12 @@ k = 13
             @test h_info[Cs] == 0
             @test B ≈ Array(d_B[:,:,Cs]) rtol=1e-3
         end
+
+        A_bad = rand(elty,m+1,m,10)
+        d_A_bad = CuArray(A_bad)
+        d_B = similar(d_A)
+        pivot, info = CUBLAS.getrf_strided_batched!(d_A, true)
+        @test_throws DimensionMismatch CUBLAS.getri_strided_batched!(d_A_bad, d_B, pivot)
     end
 
     @testset "getri_batched" begin
@@ -290,6 +333,32 @@ k = 13
             @test h_info[Cs] == 0
             @test C ≈ h_C rtol=1e-2
         end
+
+        d_A = CuArray{elty, 2}[]
+        for i in 1:length(A)
+            push!(d_A,CuArray(A[i]))
+        end
+        pivot, info = CUBLAS.getrf_batched!(d_A, true)
+        h_info = Array(info)
+        for Cs in 1:length(h_info)
+            @test h_info[Cs] == 0
+        end
+        d_C = CuMatrix{elty}[similar(d_A[1]) for i in 1:length(d_A)]
+        info = CUBLAS.getri_batched!(d_A, d_C, pivot)
+        h_info = Array(info)
+        for Cs in 1:length(d_C)
+            C = inv(A[Cs])
+            h_C = Array(d_C[Cs])
+            @test h_info[Cs] == 0
+            @test C ≈ h_C rtol=1e-2
+        end
+
+        A_bad = [rand(elty,m+1,m) for i in 1:10]
+        d_A_bad = CuArray{elty, 2}[]
+        for i in 1:length(A)
+            push!(d_A_bad,CuArray(A_bad[i]))
+        end
+        @test_throws DimensionMismatch CUBLAS.getri_batched(d_A_bad, pivot)
     end
 
     @testset "matinv_batched" begin
@@ -308,6 +377,15 @@ k = 13
         end
         push!(d_A, CUDA.rand(elty, m, m+1))
        @test_throws DimensionMismatch CUBLAS.matinv_batched(d_A)
+
+        # matinv_batched only supports matrices smaller than 32x32
+        A = [rand(elty,64,64) for i in 1:10]
+        # move to device
+        d_A_too_big = CuArray{elty, 2}[]
+        for i in 1:length(A)
+            push!(d_A_too_big,CuArray(A[i]))
+        end
+        @test_throws ArgumentError("matinv requires all matrices be smaller than 32 x 32") CUBLAS.matinv_batched(d_A_too_big)
     end
 
     @testset "geqrf_batched!" begin
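
The 32×32 limit is a cuBLAS restriction: `cublas<t>matinvBatched` is only implemented for small matrices, and the wrapper surfaces that as an `ArgumentError` rather than an opaque CUBLAS status. Larger batched inversions go through the getrf/getri pair exercised above, e.g. (a sketch mirroring those tests):

using CUDA, CUDA.CUBLAS

# 64×64 is too big for matinv_batched, so factorize then invert
As = [CUDA.rand(Float32, 64, 64) for _ in 1:4]
pivot, info, LUs = CUBLAS.getrf_batched!(As, true)
Cs = [similar(As[1]) for _ in 1:4]
info = CUBLAS.getri_batched!(As, Cs, pivot)   # Cs now holds the inverses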
@@ -343,7 +421,7 @@ k = 13
         for i in 1:length(A)
             push!(d_A,CuArray(A[i]))
         end
-        tau, d_B = CUBLAS.geqrf_batched!(d_A)
+        tau, d_B = CUBLAS.geqrf_batched(d_A)
         for Bs in 1:length(d_B)
             C = qr(A[Bs])
             h_B = Array(d_B[Bs])
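
A one-character fix with coverage implications: this block sits in the testset for the out-of-place variant, but previously called the in-place `geqrf_batched!`, so `geqrf_batched` itself was presumably never exercised.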
@@ -392,6 +470,18 @@ k = 13
         end
         # system is now not overdetermined
         @test_throws ArgumentError CUBLAS.gels_batched!('N',d_A, d_C)
+
+        # generate bad matrices
+        A = [rand(elty,n,k) for i in 1:10]
+        C = [rand(elty,n+1,k) for i in 1:10]
+        # move to device
+        d_A = CuArray{elty, 2}[]
+        d_C = CuArray{elty, 2}[]
+        for i in 1:length(A)
+            push!(d_A,CuArray(A[i]))
+            push!(d_C,CuArray(C[i]))
+        end
+        @test_throws DimensionMismatch CUBLAS.gels_batched!('N',d_A, d_C)
     end
 
     @testset "gels_batched" begin
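
`gels_batched!` solves batched least-squares problems, so each right-hand side must have as many rows as its matrix; the new test gives every `C` one extra row to trip the `DimensionMismatch` check (the earlier `ArgumentError` case covers non-overdetermined systems). For contrast, a well-formed call looks like this (a sketch based on the API as used in these tests):

using CUDA, CUDA.CUBLAS

# overdetermined: 6 equations, 3 unknowns, matching 6-row right-hand sides
As = [CUDA.rand(Float32, 6, 3) for _ in 1:4]
Bs = [CUDA.rand(Float32, 6, 1) for _ in 1:4]
CUBLAS.gels_batched!('N', As, Bs)   # solutions land in the leading rows of each Bs[i]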

test/libraries/cublas/level1.jl

Lines changed: 11 additions & 0 deletions
@@ -166,6 +166,17 @@ k = 13
     @test testf(axpy!, rand(), rand(T, m), rand(T, m))
     @test testf(LinearAlgebra.axpby!, rand(), rand(T, m), rand(), rand(T, m))
 
+    @testset "scal!" begin
+        x = rand(T, m)
+        d_x = CuArray(x)
+        α = rand(Float32)
+        d_α = CuArray([α])
+        y = α * x
+        d_x = CUBLAS.scal!(m, d_α, d_x)
+        h_y = Array(d_x)
+        @test h_y ≈ y
+    end
+
     if T <: Complex
         @test testf(dot, rand(T, m), rand(T, m))
         x = rand(T, m)
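
The new testset covers `scal!` with a device-resident scaling factor: wrapping α in a one-element `CuArray` exercises the code path where cuBLAS reads the scalar from GPU memory instead of the host. For comparison, the ordinary host-scalar call (a sketch):

using CUDA, CUDA.CUBLAS

x   = CUDA.rand(Float32, 16)
ref = 2f0 .* Array(x)
CUBLAS.scal!(16, 2f0, x)    # host scalar: the conventional form
@assert Array(x) ≈ ref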
