@@ -335,251 +335,90 @@ cl_int clBlashawaiiSgemmBig1024KernelFunctor::KernelsLaunch(cl_command_queue que
335
335
// #15: C22 = a*A23*B23 + 1*C22
336
336
// #16: C22 = a*A24*B24 + 1*C22 now we are done with C22
337
337
338
- unsigned int small_M = args.M / 2 ;
339
- unsigned int small_N = args.N / 2 ;
340
- unsigned int small_K = args.K / 4 ;
338
+ unsigned int K_split_factor = 4 ;
339
+ unsigned int M_split_factor = 2 ;
340
+ unsigned int N_split_factor = 2 ;
341
+
342
+ unsigned int small_M = args.M / M_split_factor;
343
+ unsigned int small_N = args.N / N_split_factor;
344
+ unsigned int small_K = args.K / K_split_factor;
341
345
342
346
size_t GlobalX = ((small_M - 1 ) / 128 + 1 ) * 16 ;
343
347
size_t GlobalY = ((small_N - 1 ) / 128 + 1 ) * 16 ;
344
348
std::size_t gs[2 ] = { GlobalX, GlobalY };
345
349
cl_int error = 0 ;
346
350
347
- // GEMM #1
351
+ cl_float betaone = 1 ;
352
+
348
353
error = clSetKernelArg (Kernel[0 ], 3 , sizeof (cl_uint), &small_M);
349
354
assert (error == CL_SUCCESS);
350
355
error = clSetKernelArg (Kernel[0 ], 4 , sizeof (cl_uint), &small_N);
351
356
assert (error == CL_SUCCESS);
352
357
error = clSetKernelArg (Kernel[0 ], 5 , sizeof (cl_uint), &small_K);
353
358
assert (error == CL_SUCCESS);
354
359
355
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
356
- gs, m_variantBig1024->ls , args.numEventsInWaitList , args.eventWaitList , NULL );
357
- assert (error == CL_SUCCESS);
358
-
359
- // GEMM #2: C11 = a*A12*B12 + 1*C11
360
-
361
- unsigned int offa_A2 = args.lda *args.K / 4 ;
362
- unsigned int offb_B2 = args.ldb *args.K / 4 ;
363
- cl_float betaone = 1 ;
364
-
365
- error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
366
- assert (error == CL_SUCCESS);
367
- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A2);
368
- assert (error == CL_SUCCESS);
369
- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B2);
370
- assert (error == CL_SUCCESS);
371
-
372
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
373
- gs, m_variantBig1024->ls , 0 , NULL , NULL );
374
- assert (error == CL_SUCCESS);
375
-
376
- // GEMM #3: C11 = a*A13*B13 + 1*C11
377
-
378
- unsigned int offa_A3 = args.lda *args.K / 4 * 2 ;
379
- unsigned int offb_B3 = args.ldb *args.K / 4 * 2 ;
380
-
381
- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A3);
382
- assert (error == CL_SUCCESS);
383
- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B3);
384
- assert (error == CL_SUCCESS);
385
-
386
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
387
- gs, m_variantBig1024->ls , 0 , NULL , NULL );
388
- assert (error == CL_SUCCESS);
389
-
390
- // GEMM #4: C11 = a*A14*B14 + 1*C11 now we are done with C11
391
-
392
- unsigned int offa_A4 = args.lda *args.K / 4 * 3 ;
393
- unsigned int offb_B4 = args.ldb *args.K / 4 * 3 ;
394
-
395
- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A4);
396
- assert (error == CL_SUCCESS);
397
- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B4);
398
- assert (error == CL_SUCCESS);
399
-
400
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
401
- gs, m_variantBig1024->ls , 0 , NULL , NULL );
402
- assert (error == CL_SUCCESS);
403
-
404
- // GEMM #5: C12 = a*A11*B21 + b*C12
405
- unsigned int offa_A5 = 0 ;
406
- unsigned int offb_B5 = args.N / 2 ;
407
- unsigned int offc_C5 = args.ldc *args.N / 2 ;
408
-
409
- error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &(args.beta ));
410
- assert (error == CL_SUCCESS);
411
- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A5);
412
- assert (error == CL_SUCCESS);
413
- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B5);
414
- assert (error == CL_SUCCESS);
415
- error = clSetKernelArg (Kernel[0 ], 13 , sizeof (cl_uint), &offc_C5);
416
- assert (error == CL_SUCCESS);
417
-
418
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
419
- gs, m_variantBig1024->ls , 0 , NULL , NULL );
420
- assert (error == CL_SUCCESS);
421
-
422
- // GEMM #6: C12 = a*A12*B22 + 1*C12
423
- unsigned int offa_A6 = args.lda *args.K / 4 ;
424
- unsigned int offb_B6 = args.ldb *args.K / 4 + args.N / 2 ;
425
-
426
- error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
427
- assert (error == CL_SUCCESS);
428
- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A6);
429
- assert (error == CL_SUCCESS);
430
- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B6);
431
- assert (error == CL_SUCCESS);
432
-
433
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
434
- gs, m_variantBig1024->ls , 0 , NULL , NULL );
435
- assert (error == CL_SUCCESS);
436
-
437
- // GEMM #7: C12 = a*A13*B23 + 1*C12
438
- unsigned int offa_A7 = args.lda *args.K / 4 * 2 ;
439
- unsigned int offb_B7 = args.ldb *args.K / 4 * 2 + args.N / 2 ;
440
-
441
- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A7);
442
- assert (error == CL_SUCCESS);
443
- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B7);
444
- assert (error == CL_SUCCESS);
445
-
446
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
447
- gs, m_variantBig1024->ls , 0 , NULL , NULL );
448
- assert (error == CL_SUCCESS);
449
-
450
- // GEMM #8: C12 = a*A14*B24 + 1*C12 now we are done with C12
451
- unsigned int offa_A8 = args.lda *args.K / 4 * 3 ;
452
- unsigned int offb_B8 = args.ldb *args.K / 4 * 3 + args.N / 2 ;
453
-
454
- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A8);
455
- assert (error == CL_SUCCESS);
456
- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B8);
457
- assert (error == CL_SUCCESS);
458
-
459
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
460
- gs, m_variantBig1024->ls , 0 , NULL , NULL );
461
- assert (error == CL_SUCCESS);
462
-
463
- // GEMM #9: C21 = a*A21*B11 + b*C21
464
- unsigned int offa_A9 = args.M / 2 ;
465
- unsigned int offb_B9 = 0 ;
466
- unsigned int offc_C9 = args.M / 2 ;
467
-
468
- error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &(args.beta ));
469
- assert (error == CL_SUCCESS);
470
- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A9);
471
- assert (error == CL_SUCCESS);
472
- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B9);
473
- assert (error == CL_SUCCESS);
474
- error = clSetKernelArg (Kernel[0 ], 13 , sizeof (cl_uint), &offc_C9);
475
- assert (error == CL_SUCCESS);
476
-
477
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
478
- gs, m_variantBig1024->ls , 0 , NULL , NULL );
479
- assert (error == CL_SUCCESS);
480
-
481
- // GEMM #10: C21 = a*A22*B12 + 1*C21
482
-
483
- unsigned int offa_A10 = args.lda *args.K / 4 + args.M / 2 ;
484
- unsigned int offb_B10 = args.ldb *args.K / 4 ;
485
-
486
- error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
487
- assert (error == CL_SUCCESS);
488
- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A10);
489
- assert (error == CL_SUCCESS);
490
- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B10);
491
- assert (error == CL_SUCCESS);
492
-
493
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
494
- gs, m_variantBig1024->ls , 0 , NULL , NULL );
495
- assert (error == CL_SUCCESS);
496
-
497
- // GEMM #11: C21 = a*A23*B13 + 1*C21
498
-
499
- unsigned int offa_A11 = args.lda *args.K / 4 * 2 + args.M / 2 ;
500
- unsigned int offb_B11 = args.ldb *args.K / 4 * 2 ;
501
-
502
- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A11);
503
- assert (error == CL_SUCCESS);
504
- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B11);
505
- assert (error == CL_SUCCESS);
506
-
507
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
508
- gs, m_variantBig1024->ls , 0 , NULL , NULL );
509
- assert (error == CL_SUCCESS);
510
-
511
- // GEMM #12: C21 = a*A24*B14 + 1*C21 now we are done with C21
512
-
513
- unsigned int offa_A12 = args.lda *args.K / 4 * 3 + args.M / 2 ;
514
- unsigned int offb_B12 = args.ldb *args.K / 4 * 3 ;
515
-
516
- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A12);
517
- assert (error == CL_SUCCESS);
518
- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B12);
519
- assert (error == CL_SUCCESS);
520
-
521
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
522
- gs, m_variantBig1024->ls , 0 , NULL , NULL );
523
- assert (error == CL_SUCCESS);
524
-
525
- // GEMM #13: C22 = a*A21*B21 + b*C22
526
- unsigned int offa_A13 = args.M / 2 ;
527
- unsigned int offb_B13 = args.N / 2 ;
528
- unsigned int offc_C13 = args.ldc *args.N / 2 + args.M / 2 ;
529
-
530
- error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &(args.beta ));
531
- assert (error == CL_SUCCESS);
532
- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A13);
533
- assert (error == CL_SUCCESS);
534
- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B13);
535
- assert (error == CL_SUCCESS);
536
- error = clSetKernelArg (Kernel[0 ], 13 , sizeof (cl_uint), &offc_C13);
537
- assert (error == CL_SUCCESS);
538
-
539
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
540
- gs, m_variantBig1024->ls , 0 , NULL , NULL );
541
- assert (error == CL_SUCCESS);
542
-
543
- // #14: C22 = a*A22*B22 + 1*C22
544
- unsigned int offa_A14 = args.lda *args.K / 4 + args.M / 2 ;
545
- unsigned int offb_B14 = args.ldb *args.K / 4 + args.N / 2 ;
546
-
547
- error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
548
- assert (error == CL_SUCCESS);
549
- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A14);
550
- assert (error == CL_SUCCESS);
551
- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B14);
552
- assert (error == CL_SUCCESS);
553
-
554
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
555
- gs, m_variantBig1024->ls , 0 , NULL , NULL );
556
- assert (error == CL_SUCCESS);
557
-
558
- // #15: C22 = a*A23*B23 + 1*C22
559
- unsigned int offa_A15 = args.lda *args.K / 4 * 2 + args.M / 2 ;
560
- unsigned int offb_B15 = args.ldb *args.K / 4 * 2 + args.N / 2 ;
561
-
562
- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A15);
563
- assert (error == CL_SUCCESS);
564
- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B15);
565
- assert (error == CL_SUCCESS);
566
-
567
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
568
- gs, m_variantBig1024->ls , 0 , NULL , NULL );
569
- assert (error == CL_SUCCESS);
570
-
571
- // #16: C22 = a*A24*B24 + 1*C22
572
- unsigned int offa_A16 = args.lda *args.K / 4 * 3 + args.M / 2 ;
573
- unsigned int offb_B16 = args.ldb *args.K / 4 * 3 + args.N / 2 ;
574
-
575
- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A16);
576
- assert (error == CL_SUCCESS);
577
- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B16);
578
- assert (error == CL_SUCCESS);
360
+ for (int M_split_index = 0 ; M_split_index < M_split_factor; M_split_index++)
361
+ {
362
+ // 2 groups of GEMMs split by M from the example
363
+ for (int N_split_index = 0 ; N_split_index < N_split_factor; N_split_index++)
364
+ {
365
+ // 2 groups of GEMMs split by N from the example
366
+ unsigned int offc_C = args.ldc *args.N / N_split_factor * N_split_index + args.M / M_split_factor * M_split_index + args.offC ;
367
+ error = clSetKernelArg (Kernel[0 ], 13 , sizeof (cl_uint), &offc_C);
368
+ assert (error == CL_SUCCESS);
369
+
370
+ for (int K_split_index = 0 ; K_split_index < K_split_factor; K_split_index++)
371
+ {
372
+ // 4 GEMMs split by K from the example
373
+ unsigned int offa_A = (args.M / M_split_factor * M_split_index) + (args.lda * args.K / K_split_factor * K_split_index) + args.offA ;
374
+ unsigned int offb_B = (args.N / N_split_factor * N_split_index) + (args.ldb * args.K / K_split_factor * K_split_index) + args.offB ;
375
+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A);
376
+ assert (error == CL_SUCCESS);
377
+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B);
378
+ assert (error == CL_SUCCESS);
379
+
380
+ if (K_split_index == 0 )
381
+ {
382
+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &(args.beta ));
383
+ assert (error == CL_SUCCESS);
384
+
385
+ if (M_split_index == 0 && N_split_index == 0 )
386
+ {
387
+ // very first GEMM call
388
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
389
+ gs, m_variantBig1024->ls , args.numEventsInWaitList , args.eventWaitList , NULL );
390
+ assert (error == CL_SUCCESS);
391
+ }
392
+ else
393
+ {
394
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
395
+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
396
+ assert (error == CL_SUCCESS);
397
+ }
398
+ }
399
+ else
400
+ {
401
+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
402
+ assert (error == CL_SUCCESS);
403
+
404
+ if ((M_split_index == (M_split_factor - 1 ) ) && (N_split_index == (N_split_factor - 1 )) && (K_split_index == (K_split_factor - 1 )))
405
+ {
406
+ // very last GEMM call
407
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
408
+ gs, m_variantBig1024->ls , 0 , NULL , args.events );
409
+ assert (error == CL_SUCCESS);
410
+ }
411
+ else
412
+ {
413
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
414
+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
415
+ assert (error == CL_SUCCESS);
416
+ }
417
+ }
418
+ }
419
+ }
420
+ }
579
421
580
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
581
- gs, m_variantBig1024->ls , 0 , NULL , args.events );
582
- assert (error == CL_SUCCESS);
583
422
584
423
return error;
585
424
}
0 commit comments