Skip to content

Commit f1c4cdd

Browse files
authored
Update gemmStridedBatchedEx! size checks (#2935)
Align with `gemm_strided_batched!`.
1 parent 07a9672 commit f1c4cdd

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

lib/cublas/wrappers.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,9 +1306,8 @@ function gemmStridedBatchedEx!(
13061306
@nospecialize(beta),
13071307
@nospecialize(C::AbstractArray{Tc, 3});
13081308
algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT) where {Ta, Tb, Tc}
1309-
if size(A, 3) != size(B, 3) || size(A, 3) != size(C, 3)
1310-
throw(DimensionMismatch("Batch sizes must be equal for all inputs"))
1311-
end
1309+
@assert size(A, 3) == size(C, 3) || size(A, 3) == 1 "batch size mismatch: A != C"
1310+
@assert size(B, 3) == size(C, 3) || size(B, 3) == 1 "batch size mismatch: B != C"
13121311
m = size(A, transA == 'N' ? 1 : 2)
13131312
k = size(A, transA == 'N' ? 2 : 1)
13141313
n = size(B, transB == 'N' ? 2 : 1)

test/libraries/cublas/level3/gemm.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,11 +268,14 @@ k = 13
268268
bA = rand(elty, m, k, nbatch)
269269
bB = rand(elty, k, n, nbatch)
270270
bC = rand(elty, m, n, nbatch)
271+
sB = rand(elty, k, n)
271272
bbad = rand(elty, m+1, n+1, nbatch)
272273
# move to device
273274
bd_A = CuArray{elty, 3}(bA)
274275
bd_B = CuArray{elty, 3}(bB)
275276
bd_C = CuArray{elty, 3}(bC)
277+
sd_B = CuArray{elty, 2}(sB)
278+
sdr_B = reshape(sd_B, size(sd_B)..., 1)
276279
bd_bad = CuArray{elty, 3}(bbad)
277280

278281
@testset "gemm_strided_batched!" begin
@@ -282,6 +285,14 @@ k = 13
282285
end
283286
h_C = Array(bd_C)
284287
@test bC ≈ h_C
288+
289+
CUBLAS.gemm_strided_batched!('N', 'N', alpha, bd_A, sdr_B, beta, bd_C)
290+
for i in 1:nbatch
291+
bC[:, :, i] = (alpha * bA[:, :, i]) * sB + beta * bC[:, :, i]
292+
end
293+
h_C = Array(bd_C)
294+
@test bC ≈ h_C
295+
285296
@test_throws DimensionMismatch CUBLAS.gemm_strided_batched!('N', 'N', alpha, bd_A, bd_B, beta, bd_bad)
286297
end
287298

@@ -292,6 +303,14 @@ k = 13
292303
end
293304
h_C = Array(bd_C)
294305
@test bC ≈ h_C
306+
307+
CUBLAS.gemmStridedBatchedEx!('N', 'N', alpha, bd_A, sdr_B, beta, bd_C)
308+
for i in 1:nbatch
309+
bC[:, :, i] = (alpha * bA[:, :, i]) * sB + beta * bC[:, :, i]
310+
end
311+
h_C = Array(bd_C)
312+
@test bC ≈ h_C
313+
295314
@test_throws DimensionMismatch CUBLAS.gemmStridedBatchedEx!('N', 'N', alpha, bd_A, bd_B, beta, bd_bad)
296315
end
297316

@@ -303,6 +322,14 @@ k = 13
303322
end
304323
h_C = Array(bd_C)
305324
@test bC ≈ h_C
325+
326+
bd_C = CUBLAS.gemm_strided_batched('N', 'N', bd_A, sdr_B)
327+
for i in 1:nbatch
328+
bC[:, :, i] = bA[:, :, i] * sB
329+
end
330+
h_C = Array(bd_C)
331+
@test bC ≈ h_C
332+
306333
# generate matrices
307334
bA = rand(elty, k, m, nbatch)
308335
bB = rand(elty, k, n, nbatch)

0 commit comments

Comments
 (0)