Skip to content

Commit 6653404

Browse files
authored
Merge pull request #2657 from JuliaGPU/ksh/cublas_cov
Lots more tests for CUBLAS
2 parents bd7a282 + 19ffb51 commit 6653404

File tree

3 files changed

+39
-6
lines changed

3 files changed

+39
-6
lines changed

test/libraries/cublas/level2.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,24 @@ k = 13
6161
dA = CuArray{elty, 2}[]
6262
dy = CuArray{elty, 1}[]
6363
dbad = CuArray{elty, 1}[]
64+
dx_bad = CuArray{elty, 1}[]
65+
dA_bad = CuArray{elty, 2}[]
6466
for i=1:length(A)
6567
push!(dA, CuArray(A[i]))
6668
push!(dx, CuArray(x[i]))
6769
push!(dy, CuArray(y[i]))
6870
if i < length(A) - 2
6971
push!(dbad,CuArray(dx[i]))
72+
push!(dx_bad,CuArray(dx[i]))
73+
push!(dA_bad,CuArray(A[i]))
74+
else
75+
push!(dx_bad,CUDA.rand(elty, m+1))
76+
push!(dA_bad,CUDA.rand(elty, n+1, m+1))
7077
end
7178
end
7279
@test_throws DimensionMismatch CUBLAS.gemv_batched!('N', alpha, dA, dx, beta, dbad)
80+
@test_throws DimensionMismatch CUBLAS.gemv_batched!('N', alpha, dA, dx_bad, beta, dy)
81+
@test_throws DimensionMismatch CUBLAS.gemv_batched!('N', alpha, dA_bad, dx, beta, dy)
7382
CUBLAS.gemv_batched!('N', alpha, dA, dx, beta, dy)
7483
for i=1:length(A)
7584
hy = collect(dy[i])
@@ -304,6 +313,7 @@ k = 13
304313
dx = CuArray(x)
305314

306315
function pack(A, uplo)
316+
n = size(A, 1)
307317
AP = Vector{elty}(undef, (n*(n+1))>>1)
308318
k = 1
309319
for j in 1:n
@@ -315,7 +325,7 @@ k = 13
315325
return AP
316326
end
317327

318-
if elty in ["Float32", "Float64"]
328+
if elty <: Real
319329
# pack matrices
320330
sAPU = pack(sA, :U)
321331
dsAPU = CuVector(sAPU)
@@ -337,9 +347,9 @@ k = 13
337347
hy = Array(dy)
338348
@test y hy
339349
# execute on host
340-
BLAS.spmv!('U',alpha,sAPL,x,beta,y)
350+
BLAS.spmv!('L',alpha,sAPL,x,beta,y)
341351
# execute on device
342-
CUBLAS.spmv!('U',alpha,dsAPL,dx,beta,dy)
352+
CUBLAS.spmv!('L',alpha,dsAPL,dx,beta,dy)
343353
# compare results
344354
hy = Array(dy)
345355
@test y hy
@@ -356,11 +366,11 @@ k = 13
356366
hsAPU = Array(dsAPU)
357367
@test sAPU hsAPU
358368
# execute on host
359-
BLAS.spr!('U',alpha,x,sAPL)
369+
BLAS.spr!('L',alpha,x,sAPL)
360370
# execute on device
361-
CUBLAS.spr!('U',alpha,dx,dsAPL)
371+
CUBLAS.spr!('L',alpha,dx,dsAPL)
362372
# compare results
363-
hAPL = Array(dAPL)
373+
hAPL = Array(dsAPL)
364374
@test sAPL hAPL
365375
end
366376
end

test/libraries/cublas/level3.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,10 @@ k = 13
343343
# move to host and compare
344344
h_C = Array(d_C)
345345
@test D h_C
346+
d_C = CUBLAS.geam('N','N',d_A,d_B)
347+
h_C = Array(d_C)
348+
D = A + B
349+
@test D h_C
346350
end
347351
@testset "CuMatrix -- A ± B -- $elty" begin
348352
for opa in (identity, transpose, adjoint)
@@ -502,6 +506,9 @@ k = 13
502506
h_C = triu(h_C)
503507
@test C h_C
504508
@test_throws DimensionMismatch CUBLAS.her2k!('U','N',α,d_A,d_Bbad,β,d_C)
509+
Bbad = rand(elty,m,k+1)
510+
d_Bbad = CuArray(Bbad)
511+
@test_throws DimensionMismatch CUBLAS.her2k!('U','N',α,d_A,d_Bbad,β,d_C)
505512
end
506513
@testset "her2k" begin
507514
α = rand(elty)

test/libraries/cublas/level3/gemm.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ k = 13
6262
@test C1 h_C1
6363
@test C2 h_C2
6464
@test_throws ArgumentError mul!(dhA, dhA, dsA)
65+
@test_throws DimensionMismatch CUBLAS.gemm!('N','N',one(elty),d_A,dsA,one(elty),d_C1)
6566
@test_throws DimensionMismatch mul!(d_C1, d_A, dsA)
6667
end
6768
@testset "strided gemm!" begin
@@ -95,6 +96,8 @@ k = 13
9596
C1 =*A)*B + β*C1
9697
# compare
9798
@test C1 h_C1
99+
d_Cbad = CUDA.zeros(elty, m+1, n-1)
100+
@test_throws DimensionMismatch CUBLAS.gemmEx!('N','N',α,d_A,d_B,β,d_Cbad)
98101
end
99102
end
100103
@testset "gemm" begin
@@ -194,13 +197,17 @@ k = 13
194197
bd_A = CuArray{elty, 2}[]
195198
bd_B = CuArray{elty, 2}[]
196199
bd_C = CuArray{elty, 2}[]
200+
bd_A_bad = CuArray{elty, 2}[]
197201
bd_bad = CuArray{elty, 2}[]
198202
for i in 1:length(bA)
199203
push!(bd_A,CuArray(bA[i]))
200204
push!(bd_B,CuArray(bB[i]))
201205
push!(bd_C,CuArray(bC[i]))
202206
if i < length(bA) - 2
203207
push!(bd_bad,CuArray(bC[i]))
208+
push!(bd_A_bad,CuArray(bA[i]))
209+
else
210+
push!(bd_A_bad,CUDA.rand(elty, m+1, k-1))
204211
end
205212
end
206213

@@ -214,6 +221,7 @@ k = 13
214221
@test bC[i] h_C
215222
end
216223
@test_throws DimensionMismatch CUBLAS.gemm_batched!('N','N',alpha,bd_A,bd_bad,beta,bd_C)
224+
@test_throws DimensionMismatch CUBLAS.gemm_batched!('N','N',alpha,bd_A_bad,bd_B,beta,bd_C)
217225
end
218226

219227
@testset "gemm_batched" begin
@@ -224,6 +232,7 @@ k = 13
224232
@test bC[i] h_C
225233
end
226234
@test_throws DimensionMismatch CUBLAS.gemm_batched('N','N',alpha,bd_A,bd_bad)
235+
@test_throws DimensionMismatch CUBLAS.gemm_batched('N','N',alpha,bd_A_bad,bd_B)
227236
end
228237

229238
@testset "gemmBatchedEx!" begin
@@ -236,6 +245,7 @@ k = 13
236245
@test bC[i] h_C
237246
end
238247
@test_throws DimensionMismatch CUBLAS.gemmBatchedEx!('N','N',alpha,bd_A,bd_bad,beta,bd_C)
248+
@test_throws DimensionMismatch CUBLAS.gemmBatchedEx!('N','N',alpha,bd_A_bad,bd_B,beta,bd_C)
239249
end
240250

241251
nbatch = 10
@@ -311,6 +321,9 @@ k = 13
311321
bd_A = [[CuArray(bA[i][j]) for j in 1:group_sizes[i]] for i in 1:num_groups]
312322
bd_B = [[CuArray(bB[i][j]) for j in 1:group_sizes[i]] for i in 1:num_groups]
313323
bd_C = [[CuArray(bC[i][j]) for j in 1:group_sizes[i]] for i in 1:num_groups]
324+
bd_A_bad1 = [[CuArray(bA[i][j]) for j in 1:group_sizes[i]] for i in 1:num_groups-1]
325+
bd_A_bad2 = [[CuArray(bA[i][j]) for j in 1:group_sizes[i]-1] for i in 1:num_groups]
326+
bd_A_bad3 = [[CUDA.rand(elty, 3*i+1,2*i - 1) for j in 1:group_sizes[i]] for i in 1:num_groups]
314327
@testset "gemm_grouped_batched!" begin
315328
# C = (alpha*A)*B + beta*C
316329
CUBLAS.gemm_grouped_batched!(transA,transB,alpha,bd_A,bd_B,beta,bd_C)
@@ -319,6 +332,9 @@ k = 13
319332
h_C = Array(bd_C[i][j])
320333
@test bC[i][j] h_C
321334
end
335+
@test_throws DimensionMismatch CUBLAS.gemm_grouped_batched!(transA,transB,alpha,bd_A_bad1,bd_B,beta,bd_C)
336+
@test_throws DimensionMismatch CUBLAS.gemm_grouped_batched!(transA,transB,alpha,bd_A_bad2,bd_B,beta,bd_C)
337+
@test_throws DimensionMismatch CUBLAS.gemm_grouped_batched!(transA,transB,alpha,bd_A_bad3,bd_B,beta,bd_C)
322338
end
323339

324340
@testset "gemm_grouped_batched" begin

0 commit comments

Comments
 (0)