@@ -268,11 +268,14 @@ k = 13
268268 bA = rand (elty, m, k, nbatch)
269269 bB = rand (elty, k, n, nbatch)
270270 bC = rand (elty, m, n, nbatch)
271+ sB = rand (elty, k, n)
271272 bbad = rand (elty, m+ 1 , n+ 1 , nbatch)
272273 # move to device
273274 bd_A = CuArray {elty, 3} (bA)
274275 bd_B = CuArray {elty, 3} (bB)
275276 bd_C = CuArray {elty, 3} (bC)
277+ sd_B = CuArray {elty, 2} (sB)
278+ sdr_B = reshape (sd_B, size (sd_B)... , 1 )
276279 bd_bad = CuArray {elty, 3} (bbad)
277280
278281 @testset " gemm_strided_batched!" begin
@@ -282,6 +285,14 @@ k = 13
282285 end
283286 h_C = Array (bd_C)
284287 @test bC ≈ h_C
288+
289+ CUBLAS. gemm_strided_batched! (' N' , ' N' , alpha, bd_A, sdr_B, beta, bd_C)
290+ for i in 1 : nbatch
291+ bC[:, :, i] = (alpha * bA[:, :, i]) * sB + beta * bC[:, :, i]
292+ end
293+ h_C = Array (bd_C)
294+ @test bC ≈ h_C
295+
285296 @test_throws DimensionMismatch CUBLAS. gemm_strided_batched! (' N' , ' N' , alpha, bd_A, bd_B, beta, bd_bad)
286297 end
287298
@@ -292,6 +303,14 @@ k = 13
292303 end
293304 h_C = Array (bd_C)
294305 @test bC ≈ h_C
306+
307+ CUBLAS. gemmStridedBatchedEx! (' N' , ' N' , alpha, bd_A, sdr_B, beta, bd_C)
308+ for i in 1 : nbatch
309+ bC[:, :, i] = (alpha * bA[:, :, i]) * sB + beta * bC[:, :, i]
310+ end
311+ h_C = Array (bd_C)
312+ @test bC ≈ h_C
313+
295314 @test_throws DimensionMismatch CUBLAS. gemmStridedBatchedEx! (' N' , ' N' , alpha, bd_A, bd_B, beta, bd_bad)
296315 end
297316
@@ -303,6 +322,14 @@ k = 13
303322 end
304323 h_C = Array (bd_C)
305324 @test bC ≈ h_C
325+
326+ bd_C = CUBLAS. gemm_strided_batched (' N' , ' N' , bd_A, sdr_B)
327+ for i in 1 : nbatch
328+ bC[:, :, i] = bA[:, :, i] * sB
329+ end
330+ h_C = Array (bd_C)
331+ @test bC ≈ h_C
332+
306333 # generate matrices
307334 bA = rand (elty, k, m, nbatch)
308335 bB = rand (elty, k, n, nbatch)
0 commit comments