@@ -683,6 +683,153 @@ cl_int clBlashawaiiSgemmSplitKernelFunctor::KernelsLaunch(cl_command_queue queue

   std::size_t gs[2] = {GlobalX, GlobalY};
   cl_int error = 0;
+
+  // For now this path only handles square matrix sizes where K is a multiple of 16.
+  if (args.lda == args.ldb)
+  {
+    if ((args.K % 16 == 0) && (args.lda >= 6144) && (args.ldb >= 6144))
+    {
+      if ((args.lda % 1024 == 0) && (args.ldb % 1024 == 0) && (args.transA == clblasNoTrans) && (args.transB == clblasTrans))
+      {
+        // Handles a special case where a direct call to "sgemm_NT_96_96_16..." causes a
+        // performance drop due to cache misses/thrashing: column-major NT sgemm (equivalently,
+        // row-major TN sgemm) where lda and ldb are large multiples of 1024 (such as 4096 or
+        // 6144) and K exceeds a threshold (1536 for lda = ldb = 6144).
+        int K_block_size;
+        if (args.lda == 6144)
+        {
+          K_block_size = 1536;
+        }
+        else
+        {
+          K_block_size = 128;
+        }
+
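+        // Splitting along K keeps each pass to a bounded slice of A and B, which is the
+        // workaround for the cache thrashing noted above.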
+        if (args.M % 96 == 0 && args.N % 96 == 0)
+        {
+          if (VERB) printf("===> EXECUTE KERNEL 0\n");
+          if (args.K > K_block_size)
+          {
+            // Split into several GEMM calls, each with K = K_block_size;
+            // there are at least 2 GEMM calls here.
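+            // num_of_gemm = ceil(K / K_block_size).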
+            int num_of_gemm = ((args.K - 1) / K_block_size) + 1;
+
+            // call first GEMM
+            unsigned int small_K = K_block_size;
+            setKernelArg<int>(Kernel[0], 5, small_K);
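+            // The first launch waits on the caller's event list; beta is not overridden
+            // until the later chunks.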
+            error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL, gs, m_variantSplit->ls, args.numEventsInWaitList, args.eventWaitList, NULL);
+
+            // call middle GEMMs with beta = 1 so the partial products accumulate into C
+            float beta_one = 1.0f;
+            setKernelArg<float>(Kernel[0], 7, beta_one);
+            for (int i = 1; i < num_of_gemm - 1; i++)
+            {
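+              // Advance A and B to the i-th K-chunk: with column-major NoTrans A and Trans B,
+              // K indexes the columns of both, so one chunk is (K / num_of_gemm) * ld elements.
+              // (K / num_of_gemm equals K_block_size only when K divides evenly.)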
+              unsigned offa_i = args.lda * (args.K / num_of_gemm) * i + args.offA;
+              unsigned offb_i = args.ldb * (args.K / num_of_gemm) * i + args.offB;
+              setKernelArg<int>(Kernel[0], 11, offa_i);
+              setKernelArg<int>(Kernel[0], 12, offb_i);
+              error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
+            }
+            // call last GEMM; its K may be smaller than small_K
+            unsigned int residue_K = args.K % small_K;
+            if (residue_K == 0)
+              residue_K = small_K;
+            unsigned offa_i = args.lda * (args.K / num_of_gemm) * (num_of_gemm - 1) + args.offA;
+            unsigned offb_i = args.ldb * (args.K / num_of_gemm) * (num_of_gemm - 1) + args.offB;
+            setKernelArg<int>(Kernel[0], 5, residue_K);
+            setKernelArg<int>(Kernel[0], 11, offa_i);
+            setKernelArg<int>(Kernel[0], 12, offb_i);
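+            // Only this final launch signals the caller's output event (args.events).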
+            error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL, gs, m_variantSplit->ls, 0, NULL, args.events);
+            return error;
+          }
+        }
+
+        if (args.M % 96 != 0 && args.N % 96 != 0 && args.M >= 96 && args.N >= 96)
+        {
+          if (VERB) printf("===> EXECUTE KERNEL 0, 1, 2, 3\n");
+
+          if (args.K > K_block_size)
+          {
+            int num_of_gemm = ((args.K - 1) / K_block_size) + 1;
+
+            // first 4 GEMMs
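+            // Kernel[0] runs over the full grid; Kernel[1]-Kernel[3] re-run with one or both
+            // global dimensions clamped to 16 (likely the edge/corner tiles left by the
+            // 96x96 tiling).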
+            unsigned int small_K = K_block_size;
+            setKernelArg<int>(Kernel[0], 5, small_K);
+
+            error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL, gs, m_variantSplit->ls, args.numEventsInWaitList, args.eventWaitList, NULL);
+
+            gs[0] = 16;
+            error |= clEnqueueNDRangeKernel(queue, Kernel[1], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
+
+            gs[1] = 16;
+            gs[0] = GlobalX;
+            error |= clEnqueueNDRangeKernel(queue, Kernel[2], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
+
+            gs[0] = 16; gs[1] = 16;
+            error |= clEnqueueNDRangeKernel(queue, Kernel[3], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
+
+            // middle GEMMs with beta = 1 so the partial products accumulate into C
+            float beta_one = 1.0f;
+            setKernelArg<float>(Kernel[0], 7, beta_one);
+            for (int i = 1; i < num_of_gemm - 1; i++)
+            {
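+              // Advance A and B to the i-th K-chunk, as in the aligned case above.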
+              unsigned offa_i = args.lda * (args.K / num_of_gemm) * i + args.offA;
+              unsigned offb_i = args.ldb * (args.K / num_of_gemm) * i + args.offB;
+              setKernelArg<int>(Kernel[0], 11, offa_i);
+              setKernelArg<int>(Kernel[0], 12, offb_i);
+              // restore the full global size for the body kernel
+              gs[0] = GlobalX;
+              gs[1] = GlobalY;
+
+              error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
+
+              gs[0] = 16;
+              error |= clEnqueueNDRangeKernel(queue, Kernel[1], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
+
+              gs[1] = 16;
+              gs[0] = GlobalX;
+              error |= clEnqueueNDRangeKernel(queue, Kernel[2], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
+
+              gs[0] = 16; gs[1] = 16;
+              error |= clEnqueueNDRangeKernel(queue, Kernel[3], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
+            }
+            // last 4 GEMMs; the residual K may be smaller than small_K
+            unsigned int residue_K = args.K % small_K;
+            if (residue_K == 0)
+              residue_K = small_K;
+            unsigned offa_i = args.lda * (args.K / num_of_gemm) * (num_of_gemm - 1) + args.offA;
+            unsigned offb_i = args.ldb * (args.K / num_of_gemm) * (num_of_gemm - 1) + args.offB;
+            setKernelArg<int>(Kernel[0], 5, residue_K);
+            setKernelArg<int>(Kernel[0], 11, offa_i);
+            setKernelArg<int>(Kernel[0], 12, offb_i);
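+            // Re-run all four kernels once more for the final K-chunk.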
+
+            gs[0] = GlobalX;
+            gs[1] = GlobalY;
+
+            error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
+
+            gs[0] = 16;
+            error |= clEnqueueNDRangeKernel(queue, Kernel[1], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
+
+            gs[1] = 16;
+            gs[0] = GlobalX;
+            error |= clEnqueueNDRangeKernel(queue, Kernel[2], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
+
+            gs[0] = 16; gs[1] = 16;
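+            // The final launch is the one that signals the caller's output event (args.events).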
+            error |= clEnqueueNDRangeKernel(queue, Kernel[3], 2, NULL, gs, m_variantSplit->ls, 0, NULL, args.events);
+
+            return error;
+          }
+        }
+      }
+    }
+  }
+

   if (args.M % 96 == 0 && args.N % 96 == 0)
   {