@@ -61,6 +61,9 @@ static void force_gemm_column_major(
61
61
printf (" OpenCL error %i on line %u of %s\n " , RET, __LINE__, __FILE__); \
62
62
assert (false ); \
63
63
}
64
/*
 * Return early from the enclosing function when an OpenCL call failed.
 *
 * `err` is compared against CL_SUCCESS; on mismatch it is cast to
 * clblasStatus and returned. Intended for use inside functions that
 * return clblasStatus.
 *
 * The do { } while (0) wrapper makes `returnIfErr(err);` exactly one
 * statement, so it is safe in an unbraced if/else and does not leave a
 * stray empty statement behind. `err` may be evaluated twice, so pass a
 * plain variable rather than an expression with side effects.
 */
#define returnIfErr(err) \
    do { \
        if ((err) != CL_SUCCESS) \
            return static_cast<clblasStatus>(err); \
    } while (0)
64
67
65
68
const static unsigned int numGemmKernelArgs = 14 ;
66
69
void *gemmKernelArgs[numGemmKernelArgs];
@@ -258,7 +261,7 @@ void makeGemmKernel(
258
261
/* *****************************************************************************
259
262
* Enqueue Gemm Kernel
260
263
*****************************************************************************/
261
- void enqueueGemmKernel (
264
+ cl_int enqueueGemmKernel (
262
265
cl_command_queue clQueue,
263
266
cl_kernel clKernel,
264
267
void **kernelArgs,
@@ -271,14 +274,20 @@ void makeGemmKernel(
271
274
cl_event *clEvent)
272
275
{
    /* Bind each kernel argument in index order. The first failing
     * clSetKernelArg status is returned to the caller unchanged. */
    for (unsigned int i = 0; i < numKernelArgs; i++) {
        cl_int err = clSetKernelArg(clKernel, i, kernelArgSizes[i], kernelArgs[i]);
        if (err != CL_SUCCESS)
            return err;
    }

    /* Enqueue the kernel over a 2-dimensional range with no global offset.
     * The status is returned directly as cl_int: OpenCL error codes are
     * signed, so they must not be stored in an unsigned cl_uint.
     * CL_SUCCESS is returned when the enqueue succeeds. */
    return clEnqueueNDRangeKernel(clQueue, clKernel,
        2, NULL, globalWorkSize, localWorkSize,
        numEventsInWaitList, eventWaitList, clEvent);
}
283
292
284
293
@@ -325,6 +334,8 @@ clblasGemm(
325
334
const cl_event *eventWaitList,
326
335
cl_event *events)
327
336
{
337
+
338
+
328
339
// cast types to opencl types
329
340
cl_mem A = iA;
330
341
cl_mem B = iB;
@@ -389,10 +400,13 @@ clblasGemm(
389
400
cl_int err;
390
401
cl_device_id clDevice;
391
402
err = clGetCommandQueueInfo ( commandQueues[0 ], CL_QUEUE_DEVICE, sizeof (clDevice), &clDevice, NULL );
392
- CL_CHECK (err)
403
+ // CL_CHECK(err)
404
+ returnIfErr (err);
405
+
393
406
cl_uint clDeviceNumCUs;
394
407
err = clGetDeviceInfo ( clDevice, CL_DEVICE_MAX_COMPUTE_UNITS, sizeof (clDeviceNumCUs), &clDeviceNumCUs, NULL );
395
- CL_CHECK (err)
408
+ // CL_CHECK(err)
409
+ returnIfErr (err);
396
410
unsigned int deviceIdealNumThreads = (8 /* waves per CU*/ )*(64 /* threads per wave*/ )*clDeviceNumCUs;
397
411
float optimalNumElementsPerThread = ((float )M*N) / deviceIdealNumThreads;
398
412
// optimalNumElementsPerThread = 32;
@@ -562,11 +576,12 @@ clblasGemm(
562
576
if (needTileKernel) {
563
577
// printf("enqueueing tile kernel\n");
564
578
size_t globalWorkSize[2 ] = {(M/macroTileNumRows)*workGroupNumRows, (N/macroTileNumCols)*workGroupNumCols };
565
- enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], *tileClKernel,
579
+ err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], *tileClKernel,
566
580
gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs,
567
581
globalWorkSize, localWorkSize,
568
582
numEventsInWaitList, eventWaitList,
569
583
&events[numKernelsEnqueued%numCommandQueues] );
584
+ returnIfErr (err);
570
585
numKernelsEnqueued++;
571
586
}
572
587
@@ -576,11 +591,12 @@ clblasGemm(
576
591
if (needRowKernel) {
577
592
// printf("enqueueing row kernel\n");
578
593
size_t globalWorkSize[2 ] = {1 *workGroupNumRows, (N/macroTileNumCols)*workGroupNumCols };
579
- enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], *rowClKernel,
594
+ err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], *rowClKernel,
580
595
gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs,
581
596
globalWorkSize, localWorkSize,
582
597
numEventsInWaitList, eventWaitList,
583
598
&events[numKernelsEnqueued%numCommandQueues] );
599
+ returnIfErr (err);
584
600
numKernelsEnqueued++;
585
601
}
586
602
@@ -590,11 +606,12 @@ clblasGemm(
590
606
if (needColKernel) {
591
607
// printf("enqueueing col kernel\n");
592
608
size_t globalWorkSize[2 ] = { (M/macroTileNumRows)*workGroupNumRows, 1 *workGroupNumCols };
593
- enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], *colClKernel,
609
+ err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], *colClKernel,
594
610
gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs,
595
611
globalWorkSize, localWorkSize,
596
612
numEventsInWaitList, eventWaitList,
597
613
&events[numKernelsEnqueued%numCommandQueues] );
614
+ returnIfErr (err);
598
615
numKernelsEnqueued++;
599
616
}
600
617
@@ -604,11 +621,12 @@ clblasGemm(
604
621
if (needCornerKernel) {
605
622
// printf("enqueueing corner kernel\n");
606
623
size_t globalWorkSize[2 ] = { 1 *workGroupNumRows, 1 *workGroupNumCols };
607
- enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], *cornerClKernel,
624
+ err = enqueueGemmKernel ( commandQueues[numKernelsEnqueued%numCommandQueues], *cornerClKernel,
608
625
gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs,
609
626
globalWorkSize, localWorkSize,
610
627
numEventsInWaitList, eventWaitList,
611
628
&events[numKernelsEnqueued%numCommandQueues] );
629
+ returnIfErr (err);
612
630
numKernelsEnqueued++;
613
631
}
614
632
@@ -637,6 +655,29 @@ clblasSgemm(
637
655
const cl_event *eventWaitList,
638
656
cl_event *events)
639
657
{
658
+ // check if memory objects are valid
659
+ clblasStatus clblasErr = clblasSuccess;
660
+ clblasErr = checkMemObjects (A, B, C, true , A_MAT_ERRSET, B_MAT_ERRSET, C_MAT_ERRSET);
661
+ if (clblasErr != clblasSuccess)
662
+ return clblasErr;
663
+
664
+ if (K != 0 )
665
+ {
666
+ // check matrix A
667
+ clblasErr = checkMatrixSizes (TYPE_FLOAT, order, transA, M, K, A, offA, lda, A_MAT_ERRSET);
668
+ if (clblasErr != clblasSuccess)
669
+ return clblasErr;
670
+
671
+ // check matrix B
672
+ clblasErr = checkMatrixSizes (TYPE_FLOAT, order, transB, K, N, B, offB, ldb, B_MAT_ERRSET);
673
+ if (clblasErr != clblasSuccess)
674
+ return clblasErr;
675
+ }
676
+ // check matrix C
677
+ clblasErr = checkMatrixSizes (TYPE_FLOAT, order, clblasNoTrans, M, N, C, offC, ldc, C_MAT_ERRSET);
678
+ if (clblasErr != clblasSuccess)
679
+ return clblasErr;
680
+
640
681
return clblasGemm (
641
682
order,
642
683
transA,
@@ -674,6 +715,29 @@ clblasDgemm( clblasOrder order,
674
715
const cl_event *eventWaitList,
675
716
cl_event *events)
676
717
{
718
+ // check if memory objects are valid
719
+ clblasStatus clblasErr = clblasSuccess;
720
+ clblasErr = checkMemObjects (A, B, C, true , A_MAT_ERRSET, B_MAT_ERRSET, C_MAT_ERRSET);
721
+ if (clblasErr != clblasSuccess)
722
+ return clblasErr;
723
+
724
+ if (K != 0 )
725
+ {
726
+ // check matrix A
727
+ clblasErr = checkMatrixSizes (TYPE_DOUBLE, order, transA, M, K, A, offA, lda, A_MAT_ERRSET);
728
+ if (clblasErr != clblasSuccess)
729
+ return clblasErr;
730
+
731
+ // check matrix B
732
+ clblasErr = checkMatrixSizes (TYPE_DOUBLE, order, transB, K, N, B, offB, ldb, B_MAT_ERRSET);
733
+ if (clblasErr != clblasSuccess)
734
+ return clblasErr;
735
+ }
736
+ // check matrix C
737
+ clblasErr = checkMatrixSizes (TYPE_DOUBLE, order, clblasNoTrans, M, N, C, offC, ldc, C_MAT_ERRSET);
738
+ if (clblasErr != clblasSuccess)
739
+ return clblasErr;
740
+
677
741
return clblasGemm (
678
742
order,
679
743
transA,
@@ -712,6 +776,29 @@ clblasCgemm(
712
776
const cl_event *eventWaitList,
713
777
cl_event *events)
714
778
{
779
+ // check if memory objects are valid
780
+ clblasStatus clblasErr = clblasSuccess;
781
+ clblasErr = checkMemObjects (A, B, C, true , A_MAT_ERRSET, B_MAT_ERRSET, C_MAT_ERRSET);
782
+ if (clblasErr != clblasSuccess)
783
+ return clblasErr;
784
+
785
+ if (K != 0 )
786
+ {
787
+ // check matrix A
788
+ clblasErr = checkMatrixSizes (TYPE_COMPLEX_FLOAT, order, transA, M, K, A, offA, lda, A_MAT_ERRSET);
789
+ if (clblasErr != clblasSuccess)
790
+ return clblasErr;
791
+
792
+ // check matrix B
793
+ clblasErr = checkMatrixSizes (TYPE_COMPLEX_FLOAT, order, transB, K, N, B, offB, ldb, B_MAT_ERRSET);
794
+ if (clblasErr != clblasSuccess)
795
+ return clblasErr;
796
+ }
797
+ // check matrix C
798
+ clblasErr = checkMatrixSizes (TYPE_COMPLEX_FLOAT, order, clblasNoTrans, M, N, C, offC, ldc, C_MAT_ERRSET);
799
+ if (clblasErr != clblasSuccess)
800
+ return clblasErr;
801
+
715
802
return clblasGemm (
716
803
order,
717
804
transA,
@@ -750,6 +837,29 @@ clblasZgemm(
750
837
const cl_event *eventWaitList,
751
838
cl_event *events)
752
839
{
840
+ // check if memory objects are valid
841
+ clblasStatus clblasErr = clblasSuccess;
842
+ clblasErr = checkMemObjects (A, B, C, true , A_MAT_ERRSET, B_MAT_ERRSET, C_MAT_ERRSET);
843
+ if (clblasErr != clblasSuccess)
844
+ return clblasErr;
845
+
846
+ if (K != 0 )
847
+ {
848
+ // check matrix A
849
+ clblasErr = checkMatrixSizes (TYPE_COMPLEX_DOUBLE, order, transA, M, K, A, offA, lda, A_MAT_ERRSET);
850
+ if (clblasErr != clblasSuccess)
851
+ return clblasErr;
852
+
853
+ // check matrix B
854
+ clblasErr = checkMatrixSizes (TYPE_COMPLEX_DOUBLE, order, transB, K, N, B, offB, ldb, B_MAT_ERRSET);
855
+ if (clblasErr != clblasSuccess)
856
+ return clblasErr;
857
+ }
858
+ // check matrix C
859
+ clblasErr = checkMatrixSizes (TYPE_COMPLEX_DOUBLE, order, clblasNoTrans, M, N, C, offC, ldc, C_MAT_ERRSET);
860
+ if (clblasErr != clblasSuccess)
861
+ return clblasErr;
862
+
753
863
return clblasGemm (
754
864
order,
755
865
transA,
0 commit comments