@@ -683,6 +683,153 @@ cl_int clBlashawaiiSgemmSplitKernelFunctor::KernelsLaunch(cl_command_queue queue
683
683
684
684
std::size_t gs[2 ] = {GlobalX, GlobalY};
685
685
cl_int error = 0 ;
686
+
687
+ // deals with square matrix sizes where K is mod 16 for now
688
+ if (args.lda == args.ldb )
689
+ {
690
+ if ((args.K % 16 == 0 ) && (args.lda >= 6144 ) && (args.ldb >= 6144 ))
691
+ {
692
+ if ((args.lda % 1024 == 0 ) && (args.ldb % 1024 == 0 ) && (args.transA == clblasNoTrans) && (args.transB == clblasTrans))
693
+ {
694
+ // handles special cases where a direct call to "sgemm_NT_96_96_16..." causes perf drop due to cache miss/thrashing
695
+ // this special cases is: sgemm column major NT / sgemm row major TN; lda and ldb are big multiples of 1024 such as 4096 and 6144
696
+ // K is bigger than a threshold: 1536 for lda=ldb=6144
697
+
698
+ //
699
+ int K_block_size;
700
+ if (args.lda == 6144 )
701
+ {
702
+ K_block_size = 1536 ;
703
+ }
704
+ else
705
+ {
706
+ K_block_size = 128 ;
707
+ }
708
+
709
+ if (args.M % 96 == 0 && args.N % 96 == 0 )
710
+ {
711
+ if (VERB) printf (" ===> EXECUTE KERNEL 0 \n " );
712
+ if (args.K > K_block_size)
713
+ {
714
+ // split into many GEMM calls with K = K_block_size
715
+ // there are at least 2 GEMM calls
716
+ int num_of_gemm = ((args.K - 1 ) / K_block_size) + 1 ;
717
+
718
+ // call first GEMM
719
+ unsigned int small_K = K_block_size;
720
+ setKernelArg<int >(Kernel[0 ], 5 , small_K);
721
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL , gs, m_variantSplit->ls , args.numEventsInWaitList , args.eventWaitList , NULL );
722
+
723
+ // call middle GEMMs
724
+ unsigned beta_one = 1 .0f ;
725
+ setKernelArg<int >(Kernel[0 ], 7 , beta_one);
726
+ for (int i = 1 ; i < num_of_gemm - 1 ; i++)
727
+ {
728
+ unsigned offa_i = args.lda * (args.K / num_of_gemm) * i + args.offA ;
729
+ unsigned offb_i = args.ldb * (args.K / num_of_gemm) * i + args.offB ;
730
+ setKernelArg<int >(Kernel[0 ], 11 , offa_i);
731
+ setKernelArg<int >(Kernel[0 ], 12 , offb_i);
732
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL , gs, m_variantSplit->ls , 0 , NULL , NULL );
733
+ }
734
+ // call last GEMM
735
+ // the last GEMM's K might be smaller than small_K
736
+ unsigned int residue_K = args.K % small_K;
737
+ if (residue_K == 0 )
738
+ residue_K = small_K;
739
+ unsigned offa_i = args.lda * (args.K / num_of_gemm) * (num_of_gemm - 1 ) + args.offA ;
740
+ unsigned offb_i = args.ldb * (args.K / num_of_gemm) * (num_of_gemm - 1 ) + args.offB ;
741
+ setKernelArg<int >(Kernel[0 ], 5 , residue_K);
742
+ setKernelArg<int >(Kernel[0 ], 11 , offa_i);
743
+ setKernelArg<int >(Kernel[0 ], 12 , offb_i);
744
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL , gs, m_variantSplit->ls , 0 , NULL , args.events );
745
+ return error;
746
+ }
747
+ }
748
+
749
+ if (args.M % 96 != 0 && args.N % 96 != 0 && args.M >= 96 && args.N >= 96 )
750
+ {
751
+ if (VERB) printf (" ===> EXECUTE KERNEL 0, 1, 2, 3 \n " );
752
+
753
+ if (args.K > K_block_size)
754
+ {
755
+ int num_of_gemm = ((args.K - 1 ) / K_block_size) + 1 ;
756
+
757
+ // first 4 GEMMs
758
+ unsigned int small_K = K_block_size;
759
+ setKernelArg<int >(Kernel[0 ], 5 , small_K);
760
+
761
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL , gs, m_variantSplit->ls , args.numEventsInWaitList , args.eventWaitList , NULL );
762
+
763
+ gs[0 ] = 16 ;
764
+ error |= clEnqueueNDRangeKernel (queue, Kernel[1 ], 2 , NULL , gs, m_variantSplit->ls , 0 , NULL , NULL );
765
+
766
+ gs[1 ] = 16 ;
767
+ gs[0 ] = GlobalX;
768
+ error |= clEnqueueNDRangeKernel (queue, Kernel[2 ], 2 , NULL , gs, m_variantSplit->ls , 0 , NULL , NULL );
769
+
770
+ gs[0 ] = 16 ; gs[1 ] = 16 ;
771
+ error |= clEnqueueNDRangeKernel (queue, Kernel[3 ], 2 , NULL , gs, m_variantSplit->ls , 0 , NULL , NULL );
772
+
773
+ // middle GEMMs
774
+ unsigned beta_one = 1 .0f ;
775
+ setKernelArg<int >(Kernel[0 ], 7 , beta_one);
776
+ for (int i = 1 ; i < num_of_gemm - 1 ; i++)
777
+ {
778
+ unsigned offa_i = args.lda * (args.K / num_of_gemm) * i + args.offA ;
779
+ unsigned offb_i = args.ldb * (args.K / num_of_gemm) * i + args.offB ;
780
+ setKernelArg<int >(Kernel[0 ], 11 , offa_i);
781
+ setKernelArg<int >(Kernel[0 ], 12 , offb_i);
782
+ // gs[2] = {GlobalX, GlobalY};
783
+ gs[0 ] = GlobalX;
784
+ gs[1 ] = GlobalY;
785
+
786
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL , gs, m_variantSplit->ls , 0 , NULL , NULL );
787
+
788
+ gs[0 ] = 16 ;
789
+ error |= clEnqueueNDRangeKernel (queue, Kernel[1 ], 2 , NULL , gs, m_variantSplit->ls , 0 , NULL , NULL );
790
+
791
+ gs[1 ] = 16 ;
792
+ gs[0 ] = GlobalX;
793
+ error |= clEnqueueNDRangeKernel (queue, Kernel[2 ], 2 , NULL , gs, m_variantSplit->ls , 0 , NULL , NULL );
794
+
795
+ gs[0 ] = 16 ; gs[1 ] = 16 ;
796
+ error |= clEnqueueNDRangeKernel (queue, Kernel[3 ], 2 , NULL , gs, m_variantSplit->ls , 0 , NULL , NULL );
797
+ }
798
+ // last 4 GEMMs
799
+ unsigned int residue_K = args.K % small_K;
800
+ if (residue_K == 0 )
801
+ residue_K = small_K;
802
+ unsigned offa_i = args.lda * (args.K / num_of_gemm) * (num_of_gemm - 1 ) + args.offA ;
803
+ unsigned offb_i = args.ldb * (args.K / num_of_gemm) * (num_of_gemm - 1 ) + args.offB ;
804
+ setKernelArg<int >(Kernel[0 ], 5 , residue_K);
805
+ setKernelArg<int >(Kernel[0 ], 11 , offa_i);
806
+ setKernelArg<int >(Kernel[0 ], 12 , offb_i);
807
+
808
+ gs[0 ] = GlobalX;
809
+ gs[1 ] = GlobalY;
810
+
811
+ error = clEnqueueNDRangeKernel (queue, Kernel[0 ], 2 , NULL , gs, m_variantSplit->ls , 0 , NULL , NULL );
812
+
813
+ gs[0 ] = 16 ;
814
+ error |= clEnqueueNDRangeKernel (queue, Kernel[1 ], 2 , NULL , gs, m_variantSplit->ls , 0 , NULL , NULL );
815
+
816
+ gs[1 ] = 16 ;
817
+ gs[0 ] = GlobalX;
818
+ error |= clEnqueueNDRangeKernel (queue, Kernel[2 ], 2 , NULL , gs, m_variantSplit->ls , 0 , NULL , NULL );
819
+
820
+ gs[0 ] = 16 ; gs[1 ] = 16 ;
821
+ error |= clEnqueueNDRangeKernel (queue, Kernel[3 ], 2 , NULL , gs, m_variantSplit->ls , 0 , NULL , args.events );
822
+
823
+
824
+ return error;
825
+ }
826
+ }
827
+
828
+
829
+ }
830
+ }
831
+ }
832
+
686
833
687
834
if (args.M %96 ==0 && args.N %96 ==0 )
688
835
{
0 commit comments