@@ -289,21 +289,300 @@ clBlashawaiiSgemmBig1024KernelFunctor::provide(clblasSgemmFunctor::Args & args,
289
289
290
290
cl_int clBlashawaiiSgemmBig1024KernelFunctor::KernelsLaunch (cl_command_queue queue, cl_kernel Kernel[1 ], Args &args)
291
291
{
292
- // ((Mvalue - 1) / 128 + 1) * 16
293
- size_t GlobalX = ((args.M -1 ) / 128 + 1 )*16 ;
294
-
295
- //
296
-
297
- size_t GlobalY = ((args.N - 1 ) / 128 + 1 ) * 16 ;
292
+ if (args.lda < 7168 )
293
+ {
294
+ // ((Mvalue - 1) / 128 + 1) * 16
295
+ size_t GlobalX = ((args.M - 1 ) / 128 + 1 ) * 16 ;
296
+ size_t GlobalY = ((args.N - 1 ) / 128 + 1 ) * 16 ;
298
297
299
298
300
- std::size_t gs[2 ] = {GlobalX, GlobalY};
301
- cl_int error = 0 ;
299
+ std::size_t gs[2 ] = { GlobalX, GlobalY };
300
+ cl_int error = 0 ;
302
301
303
302
304
- // if (VERB) printf(" ===> EXECUTE KERNEL 0 \n") ;
305
- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL , gs, m_variantBig1024->ls , args.numEventsInWaitList , args.eventWaitList ,args.events );
306
- return error;
303
+ // if (VERB) printf(" ===> EXECUTE KERNEL 0 \n") ;
304
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL , gs, m_variantBig1024->ls , args.numEventsInWaitList , args.eventWaitList , args.events );
305
+ return error;
306
+ }
307
+ else
308
+ {
309
+ // for example, when M=N=K=8192
310
+ // we are going to call 16 GEMMs
311
+ // each GEMM has M=N=4096, K=2048 (M and N are halved, K is quartered)
312
+ // note that a direct GEMM call has a 0.7 TFLOPS performance
313
+
314
+ // [ A11 | A12 | A13 | A14 ] [ B11 | B12 | B13 | B14 ] [ C11 | C12 ]
315
+ // A = [ A21 | A22 | A23 | A24 ] B = [ B21 | B22 | B23 | B24 ] C = [ C21 | C22 ]
316
+
317
+ // 16 GEMMs are
318
+ // #01: C11 = a*A11*B11 + b*C11
319
+ // #02: C11 = a*A12*B12 + 1*C11
320
+ // #03: C11 = a*A13*B13 + 1*C11
321
+ // #04: C11 = a*A14*B14 + 1*C11 now we are done with C11
322
+
323
+ // #05: C12 = a*A11*B21 + b*C12
324
+ // #06: C12 = a*A12*B22 + 1*C12
325
+ // #07: C12 = a*A13*B23 + 1*C12
326
+ // #08: C12 = a*A14*B24 + 1*C12 now we are done with C12
327
+
328
+ // #09: C21 = a*A21*B11 + b*C21
329
+ // #10: C21 = a*A22*B12 + 1*C21
330
+ // #11: C21 = a*A23*B13 + 1*C21
331
+ // #12: C21 = a*A24*B14 + 1*C21 now we are done with C21
332
+
333
+ // #13: C22 = a*A21*B21 + b*C22
334
+ // #14: C22 = a*A22*B22 + 1*C22
335
+ // #15: C22 = a*A23*B23 + 1*C22
336
+ // #16: C22 = a*A24*B24 + 1*C22 now we are done with C22
337
+
338
+ unsigned int small_M = args.M / 2 ;
339
+ unsigned int small_N = args.N / 2 ;
340
+ unsigned int small_K = args.K / 4 ;
341
+
342
+ size_t GlobalX = ((small_M - 1 ) / 128 + 1 ) * 16 ;
343
+ size_t GlobalY = ((small_N - 1 ) / 128 + 1 ) * 16 ;
344
+ std::size_t gs[2 ] = { GlobalX, GlobalY };
345
+ cl_int error = 0 ;
346
+
347
+ // GEMM #1
348
+ error = clSetKernelArg (Kernel[0 ], 3 , sizeof (cl_uint), &small_M);
349
+ assert (error == CL_SUCCESS);
350
+ error = clSetKernelArg (Kernel[0 ], 4 , sizeof (cl_uint), &small_N);
351
+ assert (error == CL_SUCCESS);
352
+ error = clSetKernelArg (Kernel[0 ], 5 , sizeof (cl_uint), &small_K);
353
+ assert (error == CL_SUCCESS);
354
+
355
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
356
+ gs, m_variantBig1024->ls , args.numEventsInWaitList , args.eventWaitList , NULL );
357
+ assert (error == CL_SUCCESS);
358
+
359
+ // GEMM #2: C11 = a*A12*B12 + 1*C11
360
+
361
+ unsigned int offa_A2 = args.lda *args.K / 4 ;
362
+ unsigned int offb_B2 = args.ldb *args.K / 4 ;
363
+ cl_float betaone = 1 ;
364
+
365
+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
366
+ assert (error == CL_SUCCESS);
367
+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A2);
368
+ assert (error == CL_SUCCESS);
369
+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B2);
370
+ assert (error == CL_SUCCESS);
371
+
372
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
373
+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
374
+ assert (error == CL_SUCCESS);
375
+
376
+ // GEMM #3: C11 = a*A13*B13 + 1*C11
377
+
378
+ unsigned int offa_A3 = args.lda *args.K / 4 * 2 ;
379
+ unsigned int offb_B3 = args.ldb *args.K / 4 * 2 ;
380
+
381
+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A3);
382
+ assert (error == CL_SUCCESS);
383
+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B3);
384
+ assert (error == CL_SUCCESS);
385
+
386
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
387
+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
388
+ assert (error == CL_SUCCESS);
389
+
390
+ // GEMM #4: C11 = a*A14*B14 + 1*C11 now we are done with C11
391
+
392
+ unsigned int offa_A4 = args.lda *args.K / 4 * 3 ;
393
+ unsigned int offb_B4 = args.ldb *args.K / 4 * 3 ;
394
+
395
+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A4);
396
+ assert (error == CL_SUCCESS);
397
+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B4);
398
+ assert (error == CL_SUCCESS);
399
+
400
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
401
+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
402
+ assert (error == CL_SUCCESS);
403
+
404
+ // GEMM #5: C12 = a*A11*B21 + b*C12
405
+ unsigned int offa_A5 = 0 ;
406
+ unsigned int offb_B5 = args.N / 2 ;
407
+ unsigned int offc_C5 = args.ldc *args.N / 2 ;
408
+
409
+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &(args.beta ));
410
+ assert (error == CL_SUCCESS);
411
+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A5);
412
+ assert (error == CL_SUCCESS);
413
+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B5);
414
+ assert (error == CL_SUCCESS);
415
+ error = clSetKernelArg (Kernel[0 ], 13 , sizeof (cl_uint), &offc_C5);
416
+ assert (error == CL_SUCCESS);
417
+
418
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
419
+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
420
+ assert (error == CL_SUCCESS);
421
+
422
+ // GEMM #6: C12 = a*A12*B22 + 1*C12
423
+ unsigned int offa_A6 = args.lda *args.K / 4 ;
424
+ unsigned int offb_B6 = args.ldb *args.K / 4 + args.N / 2 ;
425
+
426
+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
427
+ assert (error == CL_SUCCESS);
428
+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A6);
429
+ assert (error == CL_SUCCESS);
430
+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B6);
431
+ assert (error == CL_SUCCESS);
432
+
433
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
434
+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
435
+ assert (error == CL_SUCCESS);
436
+
437
+ // GEMM #7: C12 = a*A13*B23 + 1*C12
438
+ unsigned int offa_A7 = args.lda *args.K / 4 * 2 ;
439
+ unsigned int offb_B7 = args.ldb *args.K / 4 * 2 + args.N / 2 ;
440
+
441
+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A7);
442
+ assert (error == CL_SUCCESS);
443
+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B7);
444
+ assert (error == CL_SUCCESS);
445
+
446
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
447
+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
448
+ assert (error == CL_SUCCESS);
449
+
450
+ // GEMM #8: C12 = a*A14*B24 + 1*C12 now we are done with C12
451
+ unsigned int offa_A8 = args.lda *args.K / 4 * 3 ;
452
+ unsigned int offb_B8 = args.ldb *args.K / 4 * 3 + args.N / 2 ;
453
+
454
+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A8);
455
+ assert (error == CL_SUCCESS);
456
+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B8);
457
+ assert (error == CL_SUCCESS);
458
+
459
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
460
+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
461
+ assert (error == CL_SUCCESS);
462
+
463
+ // GEMM #9: C21 = a*A21*B11 + b*C21
464
+ unsigned int offa_A9 = args.M / 2 ;
465
+ unsigned int offb_B9 = 0 ;
466
+ unsigned int offc_C9 = args.M / 2 ;
467
+
468
+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &(args.beta ));
469
+ assert (error == CL_SUCCESS);
470
+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A9);
471
+ assert (error == CL_SUCCESS);
472
+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B9);
473
+ assert (error == CL_SUCCESS);
474
+ error = clSetKernelArg (Kernel[0 ], 13 , sizeof (cl_uint), &offc_C9);
475
+ assert (error == CL_SUCCESS);
476
+
477
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
478
+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
479
+ assert (error == CL_SUCCESS);
480
+
481
+ // GEMM #10: C21 = a*A22*B12 + 1*C21
482
+
483
+ unsigned int offa_A10 = args.lda *args.K / 4 + args.M / 2 ;
484
+ unsigned int offb_B10 = args.ldb *args.K / 4 ;
485
+
486
+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
487
+ assert (error == CL_SUCCESS);
488
+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A10);
489
+ assert (error == CL_SUCCESS);
490
+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B10);
491
+ assert (error == CL_SUCCESS);
492
+
493
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
494
+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
495
+ assert (error == CL_SUCCESS);
496
+
497
+ // GEMM #11: C21 = a*A23*B13 + 1*C21
498
+
499
+ unsigned int offa_A11 = args.lda *args.K / 4 * 2 + args.M / 2 ;
500
+ unsigned int offb_B11 = args.ldb *args.K / 4 * 2 ;
501
+
502
+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A11);
503
+ assert (error == CL_SUCCESS);
504
+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B11);
505
+ assert (error == CL_SUCCESS);
506
+
507
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
508
+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
509
+ assert (error == CL_SUCCESS);
510
+
511
+ // GEMM #12: C21 = a*A24*B14 + 1*C21 now we are done with C21
512
+
513
+ unsigned int offa_A12 = args.lda *args.K / 4 * 3 + args.M / 2 ;
514
+ unsigned int offb_B12 = args.ldb *args.K / 4 * 3 ;
515
+
516
+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A12);
517
+ assert (error == CL_SUCCESS);
518
+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B12);
519
+ assert (error == CL_SUCCESS);
520
+
521
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
522
+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
523
+ assert (error == CL_SUCCESS);
524
+
525
+ // GEMM #13: C22 = a*A21*B21 + b*C22
526
+ unsigned int offa_A13 = args.M / 2 ;
527
+ unsigned int offb_B13 = args.N / 2 ;
528
+ unsigned int offc_C13 = args.ldc *args.N / 2 + args.M / 2 ;
529
+
530
+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &(args.beta ));
531
+ assert (error == CL_SUCCESS);
532
+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A13);
533
+ assert (error == CL_SUCCESS);
534
+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B13);
535
+ assert (error == CL_SUCCESS);
536
+ error = clSetKernelArg (Kernel[0 ], 13 , sizeof (cl_uint), &offc_C13);
537
+ assert (error == CL_SUCCESS);
538
+
539
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
540
+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
541
+ assert (error == CL_SUCCESS);
542
+
543
+ // #14: C22 = a*A22*B22 + 1*C22
544
+ unsigned int offa_A14 = args.lda *args.K / 4 + args.M / 2 ;
545
+ unsigned int offb_B14 = args.ldb *args.K / 4 + args.N / 2 ;
546
+
547
+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
548
+ assert (error == CL_SUCCESS);
549
+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A14);
550
+ assert (error == CL_SUCCESS);
551
+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B14);
552
+ assert (error == CL_SUCCESS);
553
+
554
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
555
+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
556
+ assert (error == CL_SUCCESS);
557
+
558
+ // #15: C22 = a*A23*B23 + 1*C22
559
+ unsigned int offa_A15 = args.lda *args.K / 4 * 2 + args.M / 2 ;
560
+ unsigned int offb_B15 = args.ldb *args.K / 4 * 2 + args.N / 2 ;
561
+
562
+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A15);
563
+ assert (error == CL_SUCCESS);
564
+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B15);
565
+ assert (error == CL_SUCCESS);
566
+
567
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
568
+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
569
+ assert (error == CL_SUCCESS);
570
+
571
+ // #16: C22 = a*A24*B24 + 1*C22
572
+ unsigned int offa_A16 = args.lda *args.K / 4 * 3 + args.M / 2 ;
573
+ unsigned int offb_B16 = args.ldb *args.K / 4 * 3 + args.N / 2 ;
574
+
575
+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A16);
576
+ assert (error == CL_SUCCESS);
577
+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B16);
578
+ assert (error == CL_SUCCESS);
579
+
580
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
581
+ gs, m_variantBig1024->ls , 0 , NULL , args.events );
582
+ assert (error == CL_SUCCESS);
583
+
584
+ return error;
585
+ }
307
586
308
587
309
588
return clblasNotImplemented;
0 commit comments