@@ -289,21 +289,300 @@ clBlashawaiiSgemmBig1024KernelFunctor::provide(clblasSgemmFunctor::Args & args,
289289
290290cl_int clBlashawaiiSgemmBig1024KernelFunctor::KernelsLaunch (cl_command_queue queue, cl_kernel Kernel[1 ], Args &args)
291291{
292- // ((Mvalue - 1) / 128 + 1) * 16
293- size_t GlobalX = ((args.M -1 ) / 128 + 1 )*16 ;
294-
295- //
296-
297- size_t GlobalY = ((args.N - 1 ) / 128 + 1 ) * 16 ;
292+ if (args.lda < 7168 )
293+ {
294+ // ((Mvalue - 1) / 128 + 1) * 16
295+ size_t GlobalX = ((args.M - 1 ) / 128 + 1 ) * 16 ;
296+ size_t GlobalY = ((args.N - 1 ) / 128 + 1 ) * 16 ;
298297
299298
300- std::size_t gs[2 ] = {GlobalX, GlobalY};
301- cl_int error = 0 ;
299+ std::size_t gs[2 ] = { GlobalX, GlobalY };
300+ cl_int error = 0 ;
302301
303302
304- // if (VERB) printf(" ===> EXECUTE KERNEL 0 \n") ;
305- error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL , gs, m_variantBig1024->ls , args.numEventsInWaitList , args.eventWaitList ,args.events );
306- return error;
303+ // if (VERB) printf(" ===> EXECUTE KERNEL 0 \n") ;
304+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL , gs, m_variantBig1024->ls , args.numEventsInWaitList , args.eventWaitList , args.events );
305+ return error;
306+ }
307+ else
308+ {
309+ // for example, when M=N=K=8192
310+ // we are gonna call 16 GEMMs
311+ // each GEMM has M=N=K=4096
312+ // note are direct GEMM call has a 0.7 TFLOPS performance
313+
314+ // [ A11 | A12 | A13 | A14 ] [ B11 | B12 | B13 | B14 ] [ C11 | C12 ]
315+ // A = [ A21 | A22 | A23 | A24 ] B = [ B21 | B22 | B23 | B24 ] C = [ C21 | C22 ]
316+
317+ // 16 GEMMs are
318+ // #01: C11 = a*A11*B11 + b*C11
319+ // #02: C11 = a*A12*B12 + 1*C11
320+ // #03: C11 = a*A13*B13 + 1*C11
321+ // #04: C11 = a*A14*B14 + 1*C11 now we are done with C11
322+
323+ // #05: C12 = a*A11*B21 + b*C12
324+ // #06: C12 = a*A12*B22 + 1*C12
325+ // #07: C12 = a*A12*B22 + 1*C12
326+ // #08: C12 = a*A12*B22 + 1*C12 now we are done with C12
327+
328+ // #09: C21 = a*A21*B11 + b*C21
329+ // #10: C21 = a*A22*B12 + 1*C21
330+ // #11: C21 = a*A23*B13 + 1*C21
331+ // #12: C21 = a*A24*B14 + 1*C21 now we are done with C21
332+
333+ // #13: C22 = a*A21*B21 + b*C22
334+ // #14: C22 = a*A22*B22 + 1*C22
335+ // #15: C22 = a*A23*B23 + 1*C22
336+ // #16: C22 = a*A24*B24 + 1*C22 now we are done with C22
337+
338+ unsigned int small_M = args.M / 2 ;
339+ unsigned int small_N = args.N / 2 ;
340+ unsigned int small_K = args.K / 4 ;
341+
342+ size_t GlobalX = ((small_M - 1 ) / 128 + 1 ) * 16 ;
343+ size_t GlobalY = ((small_N - 1 ) / 128 + 1 ) * 16 ;
344+ std::size_t gs[2 ] = { GlobalX, GlobalY };
345+ cl_int error = 0 ;
346+
347+ // GEMM #1
348+ error = clSetKernelArg (Kernel[0 ], 3 , sizeof (cl_uint), &small_M);
349+ assert (error == CL_SUCCESS);
350+ error = clSetKernelArg (Kernel[0 ], 4 , sizeof (cl_uint), &small_N);
351+ assert (error == CL_SUCCESS);
352+ error = clSetKernelArg (Kernel[0 ], 5 , sizeof (cl_uint), &small_K);
353+ assert (error == CL_SUCCESS);
354+
355+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
356+ gs, m_variantBig1024->ls , args.numEventsInWaitList , args.eventWaitList , NULL );
357+ assert (error == CL_SUCCESS);
358+
359+ // GEMM #2: C11 = a*A12*B12 + 1*C11
360+
361+ unsigned int offa_A2 = args.lda *args.K / 4 ;
362+ unsigned int offb_B2 = args.ldb *args.K / 4 ;
363+ cl_float betaone = 1 ;
364+
365+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
366+ assert (error == CL_SUCCESS);
367+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A2);
368+ assert (error == CL_SUCCESS);
369+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B2);
370+ assert (error == CL_SUCCESS);
371+
372+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
373+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
374+ assert (error == CL_SUCCESS);
375+
376+ // GEMM #3: C11 = a*A13*B13 + 1*C11
377+
378+ unsigned int offa_A3 = args.lda *args.K / 4 * 2 ;
379+ unsigned int offb_B3 = args.ldb *args.K / 4 * 2 ;
380+
381+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A3);
382+ assert (error == CL_SUCCESS);
383+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B3);
384+ assert (error == CL_SUCCESS);
385+
386+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
387+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
388+ assert (error == CL_SUCCESS);
389+
390+ // GEMM #4: C11 = a*A14*B14 + 1*C11 now we are done with C11
391+
392+ unsigned int offa_A4 = args.lda *args.K / 4 * 3 ;
393+ unsigned int offb_B4 = args.ldb *args.K / 4 * 3 ;
394+
395+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A4);
396+ assert (error == CL_SUCCESS);
397+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B4);
398+ assert (error == CL_SUCCESS);
399+
400+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
401+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
402+ assert (error == CL_SUCCESS);
403+
404+ // GEMM #5: C12 = a*A11*B21 + b*C12
405+ unsigned int offa_A5 = 0 ;
406+ unsigned int offb_B5 = args.N / 2 ;
407+ unsigned int offc_C5 = args.ldc *args.N / 2 ;
408+
409+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &(args.beta ));
410+ assert (error == CL_SUCCESS);
411+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A5);
412+ assert (error == CL_SUCCESS);
413+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B5);
414+ assert (error == CL_SUCCESS);
415+ error = clSetKernelArg (Kernel[0 ], 13 , sizeof (cl_uint), &offc_C5);
416+ assert (error == CL_SUCCESS);
417+
418+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
419+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
420+ assert (error == CL_SUCCESS);
421+
422+ // GEMM #6: C12 = a*A12*B22 + 1*C12
423+ unsigned int offa_A6 = args.lda *args.K / 4 ;
424+ unsigned int offb_B6 = args.ldb *args.K / 4 + args.N / 2 ;
425+
426+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
427+ assert (error == CL_SUCCESS);
428+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A6);
429+ assert (error == CL_SUCCESS);
430+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B6);
431+ assert (error == CL_SUCCESS);
432+
433+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
434+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
435+ assert (error == CL_SUCCESS);
436+
437+ // GEMM #7: C12 = a*A13*B23 + 1*C12
438+ unsigned int offa_A7 = args.lda *args.K / 4 * 2 ;
439+ unsigned int offb_B7 = args.ldb *args.K / 4 * 2 + args.N / 2 ;
440+
441+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A7);
442+ assert (error == CL_SUCCESS);
443+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B7);
444+ assert (error == CL_SUCCESS);
445+
446+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
447+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
448+ assert (error == CL_SUCCESS);
449+
450+ // GEMM #8: C12 = a*A14*B24 + 1*C12 now we are done with C12
451+ unsigned int offa_A8 = args.lda *args.K / 4 * 3 ;
452+ unsigned int offb_B8 = args.ldb *args.K / 4 * 3 + args.N / 2 ;
453+
454+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A8);
455+ assert (error == CL_SUCCESS);
456+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B8);
457+ assert (error == CL_SUCCESS);
458+
459+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
460+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
461+ assert (error == CL_SUCCESS);
462+
463+ // GEMM #9: C21 = a*A21*B11 + b*C21
464+ unsigned int offa_A9 = args.M / 2 ;
465+ unsigned int offb_B9 = 0 ;
466+ unsigned int offc_C9 = args.M / 2 ;
467+
468+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &(args.beta ));
469+ assert (error == CL_SUCCESS);
470+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A9);
471+ assert (error == CL_SUCCESS);
472+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B9);
473+ assert (error == CL_SUCCESS);
474+ error = clSetKernelArg (Kernel[0 ], 13 , sizeof (cl_uint), &offc_C9);
475+ assert (error == CL_SUCCESS);
476+
477+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
478+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
479+ assert (error == CL_SUCCESS);
480+
481+ // GEMM #10: C21 = a*A22*B12 + 1*C21
482+
483+ unsigned int offa_A10 = args.lda *args.K / 4 + args.M / 2 ;
484+ unsigned int offb_B10 = args.ldb *args.K / 4 ;
485+
486+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
487+ assert (error == CL_SUCCESS);
488+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A10);
489+ assert (error == CL_SUCCESS);
490+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B10);
491+ assert (error == CL_SUCCESS);
492+
493+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
494+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
495+ assert (error == CL_SUCCESS);
496+
497+ // GEMM #11: C21 = a*A23*B13 + 1*C21
498+
499+ unsigned int offa_A11 = args.lda *args.K / 4 * 2 + args.M / 2 ;
500+ unsigned int offb_B11 = args.ldb *args.K / 4 * 2 ;
501+
502+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A11);
503+ assert (error == CL_SUCCESS);
504+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B11);
505+ assert (error == CL_SUCCESS);
506+
507+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
508+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
509+ assert (error == CL_SUCCESS);
510+
511+ // GEMM #12: C21 = a*A24*B14 + 1*C21 now we are done with C21
512+
513+ unsigned int offa_A12 = args.lda *args.K / 4 * 3 + args.M / 2 ;
514+ unsigned int offb_B12 = args.ldb *args.K / 4 * 3 ;
515+
516+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A12);
517+ assert (error == CL_SUCCESS);
518+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B12);
519+ assert (error == CL_SUCCESS);
520+
521+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
522+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
523+ assert (error == CL_SUCCESS);
524+
525+ // GEMM #13: C22 = a*A21*B21 + b*C22
526+ unsigned int offa_A13 = args.M / 2 ;
527+ unsigned int offb_B13 = args.N / 2 ;
528+ unsigned int offc_C13 = args.ldc *args.N / 2 + args.M / 2 ;
529+
530+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &(args.beta ));
531+ assert (error == CL_SUCCESS);
532+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A13);
533+ assert (error == CL_SUCCESS);
534+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B13);
535+ assert (error == CL_SUCCESS);
536+ error = clSetKernelArg (Kernel[0 ], 13 , sizeof (cl_uint), &offc_C13);
537+ assert (error == CL_SUCCESS);
538+
539+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
540+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
541+ assert (error == CL_SUCCESS);
542+
543+ // #14: C22 = a*A22*B22 + 1*C22
544+ unsigned int offa_A14 = args.lda *args.K / 4 + args.M / 2 ;
545+ unsigned int offb_B14 = args.ldb *args.K / 4 + args.N / 2 ;
546+
547+ error = clSetKernelArg (Kernel[0 ], 7 , sizeof (cl_float), &betaone);
548+ assert (error == CL_SUCCESS);
549+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A14);
550+ assert (error == CL_SUCCESS);
551+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B14);
552+ assert (error == CL_SUCCESS);
553+
554+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
555+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
556+ assert (error == CL_SUCCESS);
557+
558+ // #15: C22 = a*A23*B23 + 1*C22
559+ unsigned int offa_A15 = args.lda *args.K / 4 * 2 + args.M / 2 ;
560+ unsigned int offb_B15 = args.ldb *args.K / 4 * 2 + args.N / 2 ;
561+
562+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A15);
563+ assert (error == CL_SUCCESS);
564+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B15);
565+ assert (error == CL_SUCCESS);
566+
567+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
568+ gs, m_variantBig1024->ls , 0 , NULL , NULL );
569+ assert (error == CL_SUCCESS);
570+
571+ // #16: C22 = a*A24*B24 + 1*C22
572+ unsigned int offa_A16 = args.lda *args.K / 4 * 3 + args.M / 2 ;
573+ unsigned int offb_B16 = args.ldb *args.K / 4 * 3 + args.N / 2 ;
574+
575+ error = clSetKernelArg (Kernel[0 ], 11 , sizeof (cl_uint), &offa_A16);
576+ assert (error == CL_SUCCESS);
577+ error = clSetKernelArg (Kernel[0 ], 12 , sizeof (cl_uint), &offb_B16);
578+ assert (error == CL_SUCCESS);
579+
580+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL ,
581+ gs, m_variantBig1024->ls , 0 , NULL , args.events );
582+ assert (error == CL_SUCCESS);
583+
584+ return error;
585+ }
307586
308587
309588 return clblasNotImplemented;
0 commit comments