@@ -335,251 +335,90 @@ cl_int clBlashawaiiSgemmBig1024KernelFunctor::KernelsLaunch(cl_command_queue que
335335 // #15: C22 = a*A23*B23 + 1*C22
336336 // #16: C22 = a*A24*B24 + 1*C22 now we are done with C22
337337
338- unsigned int small_M = args.M / 2 ;
339- unsigned int small_N = args.N / 2 ;
340- unsigned int small_K = args.K / 4 ;
338+ unsigned int K_split_factor = 4 ;
339+ unsigned int M_split_factor = 2 ;
340+ unsigned int N_split_factor = 2 ;
341+
342+ unsigned int small_M = args.M / M_split_factor;
343+ unsigned int small_N = args.N / N_split_factor;
344+ unsigned int small_K = args.K / K_split_factor;
341345
342346 size_t GlobalX = ((small_M - 1 ) / 128 + 1 ) * 16 ;
343347 size_t GlobalY = ((small_N - 1 ) / 128 + 1 ) * 16 ;
344348 std::size_t gs[2 ] = { GlobalX, GlobalY };
345349 cl_int error = 0 ;
346350
347- // GEMM #1
351+ cl_float betaone = 1 ;
352+
348353 error = clSetKernelArg (Kernel[0 ], 3 , sizeof (cl_uint), &small_M);
349354 assert (error == CL_SUCCESS);
350355 error = clSetKernelArg (Kernel[0 ], 4 , sizeof (cl_uint), &small_N);
351356 assert (error == CL_SUCCESS);
352357 error = clSetKernelArg (Kernel[0 ], 5 , sizeof (cl_uint), &small_K);
353358 assert (error == CL_SUCCESS);
354359
355- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
356- gs, m_variantBig1024->ls , args.numEventsInWaitList , args.eventWaitList , NULL );
357- assert (error == CL_SUCCESS);
358-
359- // GEMM #2: C11 = a*A12*B12 + 1*C11
360-
361- unsigned int offa_A2 = args.lda *args.K / 4 ;
362- unsigned int offb_B2 = args.ldb *args.K / 4 ;
363- cl_float betaone = 1 ;
364-
365- error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
366- assert (error == CL_SUCCESS);
367- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A2);
368- assert (error == CL_SUCCESS);
369- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B2);
370- assert (error == CL_SUCCESS);
371-
372- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
373- gs, m_variantBig1024->ls , 0 , NULL , NULL );
374- assert (error == CL_SUCCESS);
375-
376- // GEMM #3: C11 = a*A13*B13 + 1*C11
377-
378- unsigned int offa_A3 = args.lda *args.K / 4 * 2 ;
379- unsigned int offb_B3 = args.ldb *args.K / 4 * 2 ;
380-
381- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A3);
382- assert (error == CL_SUCCESS);
383- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B3);
384- assert (error == CL_SUCCESS);
385-
386- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
387- gs, m_variantBig1024->ls , 0 , NULL , NULL );
388- assert (error == CL_SUCCESS);
389-
390- // GEMM #4: C11 = a*A14*B14 + 1*C11 now we are done with C11
391-
392- unsigned int offa_A4 = args.lda *args.K / 4 * 3 ;
393- unsigned int offb_B4 = args.ldb *args.K / 4 * 3 ;
394-
395- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A4);
396- assert (error == CL_SUCCESS);
397- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B4);
398- assert (error == CL_SUCCESS);
399-
400- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
401- gs, m_variantBig1024->ls , 0 , NULL , NULL );
402- assert (error == CL_SUCCESS);
403-
404- // GEMM #5: C12 = a*A11*B21 + b*C12
405- unsigned int offa_A5 = 0 ;
406- unsigned int offb_B5 = args.N / 2 ;
407- unsigned int offc_C5 = args.ldc *args.N / 2 ;
408-
409- error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &(args.beta ));
410- assert (error == CL_SUCCESS);
411- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A5);
412- assert (error == CL_SUCCESS);
413- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B5);
414- assert (error == CL_SUCCESS);
415- error = clSetKernelArg (Kernel[0 ], 13 , sizeof (cl_uint), &offc_C5);
416- assert (error == CL_SUCCESS);
417-
418- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
419- gs, m_variantBig1024->ls , 0 , NULL , NULL );
420- assert (error == CL_SUCCESS);
421-
422- // GEMM #6: C12 = a*A12*B22 + 1*C12
423- unsigned int offa_A6 = args.lda *args.K / 4 ;
424- unsigned int offb_B6 = args.ldb *args.K / 4 + args.N / 2 ;
425-
426- error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
427- assert (error == CL_SUCCESS);
428- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A6);
429- assert (error == CL_SUCCESS);
430- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B6);
431- assert (error == CL_SUCCESS);
432-
433- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
434- gs, m_variantBig1024->ls , 0 , NULL , NULL );
435- assert (error == CL_SUCCESS);
436-
437- // GEMM #7: C12 = a*A13*B23 + 1*C12
438- unsigned int offa_A7 = args.lda *args.K / 4 * 2 ;
439- unsigned int offb_B7 = args.ldb *args.K / 4 * 2 + args.N / 2 ;
440-
441- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A7);
442- assert (error == CL_SUCCESS);
443- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B7);
444- assert (error == CL_SUCCESS);
445-
446- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
447- gs, m_variantBig1024->ls , 0 , NULL , NULL );
448- assert (error == CL_SUCCESS);
449-
450- // GEMM #8: C12 = a*A14*B24 + 1*C12 now we are done with C12
451- unsigned int offa_A8 = args.lda *args.K / 4 * 3 ;
452- unsigned int offb_B8 = args.ldb *args.K / 4 * 3 + args.N / 2 ;
453-
454- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A8);
455- assert (error == CL_SUCCESS);
456- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B8);
457- assert (error == CL_SUCCESS);
458-
459- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
460- gs, m_variantBig1024->ls , 0 , NULL , NULL );
461- assert (error == CL_SUCCESS);
462-
463- // GEMM #9: C21 = a*A21*B11 + b*C21
464- unsigned int offa_A9 = args.M / 2 ;
465- unsigned int offb_B9 = 0 ;
466- unsigned int offc_C9 = args.M / 2 ;
467-
468- error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &(args.beta ));
469- assert (error == CL_SUCCESS);
470- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A9);
471- assert (error == CL_SUCCESS);
472- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B9);
473- assert (error == CL_SUCCESS);
474- error = clSetKernelArg (Kernel[0 ], 13 , sizeof (cl_uint), &offc_C9);
475- assert (error == CL_SUCCESS);
476-
477- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
478- gs, m_variantBig1024->ls , 0 , NULL , NULL );
479- assert (error == CL_SUCCESS);
480-
481- // GEMM #10: C21 = a*A22*B12 + 1*C21
482-
483- unsigned int offa_A10 = args.lda *args.K / 4 + args.M / 2 ;
484- unsigned int offb_B10 = args.ldb *args.K / 4 ;
485-
486- error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
487- assert (error == CL_SUCCESS);
488- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A10);
489- assert (error == CL_SUCCESS);
490- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B10);
491- assert (error == CL_SUCCESS);
492-
493- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
494- gs, m_variantBig1024->ls , 0 , NULL , NULL );
495- assert (error == CL_SUCCESS);
496-
497- // GEMM #11: C21 = a*A23*B13 + 1*C21
498-
499- unsigned int offa_A11 = args.lda *args.K / 4 * 2 + args.M / 2 ;
500- unsigned int offb_B11 = args.ldb *args.K / 4 * 2 ;
501-
502- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A11);
503- assert (error == CL_SUCCESS);
504- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B11);
505- assert (error == CL_SUCCESS);
506-
507- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
508- gs, m_variantBig1024->ls , 0 , NULL , NULL );
509- assert (error == CL_SUCCESS);
510-
511- // GEMM #12: C21 = a*A24*B14 + 1*C21 now we are done with C21
512-
513- unsigned int offa_A12 = args.lda *args.K / 4 * 3 + args.M / 2 ;
514- unsigned int offb_B12 = args.ldb *args.K / 4 * 3 ;
515-
516- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A12);
517- assert (error == CL_SUCCESS);
518- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B12);
519- assert (error == CL_SUCCESS);
520-
521- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
522- gs, m_variantBig1024->ls , 0 , NULL , NULL );
523- assert (error == CL_SUCCESS);
524-
525- // GEMM #13: C22 = a*A21*B21 + b*C22
526- unsigned int offa_A13 = args.M / 2 ;
527- unsigned int offb_B13 = args.N / 2 ;
528- unsigned int offc_C13 = args.ldc *args.N / 2 + args.M / 2 ;
529-
530- error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &(args.beta ));
531- assert (error == CL_SUCCESS);
532- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A13);
533- assert (error == CL_SUCCESS);
534- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B13);
535- assert (error == CL_SUCCESS);
536- error = clSetKernelArg (Kernel[0 ], 13 , sizeof (cl_uint), &offc_C13);
537- assert (error == CL_SUCCESS);
538-
539- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
540- gs, m_variantBig1024->ls , 0 , NULL , NULL );
541- assert (error == CL_SUCCESS);
542-
543- // #14: C22 = a*A22*B22 + 1*C22
544- unsigned int offa_A14 = args.lda *args.K / 4 + args.M / 2 ;
545- unsigned int offb_B14 = args.ldb *args.K / 4 + args.N / 2 ;
546-
547- error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
548- assert (error == CL_SUCCESS);
549- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A14);
550- assert (error == CL_SUCCESS);
551- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B14);
552- assert (error == CL_SUCCESS);
553-
554- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
555- gs, m_variantBig1024->ls , 0 , NULL , NULL );
556- assert (error == CL_SUCCESS);
557-
558- // #15: C22 = a*A23*B23 + 1*C22
559- unsigned int offa_A15 = args.lda *args.K / 4 * 2 + args.M / 2 ;
560- unsigned int offb_B15 = args.ldb *args.K / 4 * 2 + args.N / 2 ;
561-
562- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A15);
563- assert (error == CL_SUCCESS);
564- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B15);
565- assert (error == CL_SUCCESS);
566-
567- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
568- gs, m_variantBig1024->ls , 0 , NULL , NULL );
569- assert (error == CL_SUCCESS);
570-
571- // #16: C22 = a*A24*B24 + 1*C22
572- unsigned int offa_A16 = args.lda *args.K / 4 * 3 + args.M / 2 ;
573- unsigned int offb_B16 = args.ldb *args.K / 4 * 3 + args.N / 2 ;
574-
575- error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A16);
576- assert (error == CL_SUCCESS);
577- error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B16);
578- assert (error == CL_SUCCESS);
360+ for (int M_split_index = 0 ; M_split_index < M_split_factor; M_split_index++)
361+ {
362+ // 2 groups of GEMMs split by M, as in the example above
363+ for (int N_split_index = 0 ; N_split_index < N_split_factor; N_split_index++)
364+ {
365+ // 2 groups of GEMMs split by N, as in the example above
366+ unsigned int offc_C = args.ldc *args.N / N_split_factor * N_split_index + args.M / M_split_factor * M_split_index + args.offC ;
367+ error = clSetKernelArg (Kernel[0 ], 13 , sizeof (cl_uint), &offc_C);
368+ assert (error == CL_SUCCESS);
369+
370+ for (int K_split_index = 0 ; K_split_index < K_split_factor; K_split_index++)
371+ {
372+ // 4 GEMMs split by K, as in the example above
373+ unsigned int offa_A = (args.M / M_split_factor * M_split_index) + (args.lda * args.K / K_split_factor * K_split_index) + args.offA ;
374+ unsigned int offb_B = (args.N / N_split_factor * N_split_index) + (args.ldb * args.K / K_split_factor * K_split_index) + args.offB ;
375+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A);
376+ assert (error == CL_SUCCESS);
377+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B);
378+ assert (error == CL_SUCCESS);
379+
380+ if (K_split_index == 0 )
381+ {
382+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &(args.beta ));
383+ assert (error == CL_SUCCESS);
384+
385+ if (M_split_index == 0 && N_split_index == 0 )
386+ {
387+ // very first GEMM call
388+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
389+ gs, m_variantBig1024->ls , args.numEventsInWaitList , args.eventWaitList , NULL );
390+ assert (error == CL_SUCCESS);
391+ }
392+ else
393+ {
394+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
395+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
396+ assert (error == CL_SUCCESS);
397+ }
398+ }
399+ else
400+ {
401+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
402+ assert (error == CL_SUCCESS);
403+
404+ if ((M_split_index == (M_split_factor - 1 ) ) && (N_split_index == (N_split_factor - 1 )) && (K_split_index == (K_split_factor - 1 )))
405+ {
406+ // very last GEMM call
407+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
408+ gs, m_variantBig1024->ls , 0 , NULL , args.events );
409+ assert (error == CL_SUCCESS);
410+ }
411+ else
412+ {
413+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
414+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
415+ assert (error == CL_SUCCESS);
416+ }
417+ }
418+ }
419+ }
420+ }
579421
580- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
581- gs, m_variantBig1024->ls , 0 , NULL , args.events );
582- assert (error == CL_SUCCESS);
583422
584423 return error;
585424 }
0 commit comments