Skip to content

Commit d3d36e0

Browse files
author
Timmy
committed
fix sgemm NT perf drop when lda=ldb=6144 and k>1536
1 parent 63ca259 commit d3d36e0

File tree

5 files changed

+417
-0
lines changed

5 files changed

+417
-0
lines changed

src/library/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ set (SRC_CL_TEMPLATES
234234
dtrsm_gpu192.cl
235235
dgemm_gcn_SmallMatrices.cl
236236
sgemm_gcn_SmallMatrices.cl
237+
sgemm_gcn_bigMatrices.cl
237238
sgemm_gcn.cl
238239
zgemm_gcn.cl
239240
)
@@ -253,6 +254,9 @@ set(SRC_CL_TEMPLATES_GEN
253254
sgemm_gcn_SmallMatrices.clHawaii_64.bin.cl
254255
sgemm_gcn_SmallMatrices.clTahiti_64.bin.cl
255256
sgemm_gcn_SmallMatrices.clBonaire_64.bin.cl
257+
sgemm_gcn_bigMatrices.clHawaii_64.bin.cl
258+
sgemm_gcn_bigMatrices.clTahiti_64.bin.cl
259+
sgemm_gcn_bigMatrices.clBonaire_64.bin.cl
256260
sgemm_gcn.clHawaii_64.bin.cl
257261
zgemm_gcn.clHawaii_64.bin.cl
258262
sgemm_gcn.clBonaire_64.bin.cl

src/library/bingen.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ ${CLTEMPLATE_PATH}/sgemm_hawaiiSplitKernel.cl
1515
${CLTEMPLATE_PATH}/sgemm_gcn.cl
1616
${CLTEMPLATE_PATH}/zgemm_gcn.cl
1717
${CLTEMPLATE_PATH}/sgemm_gcn_SmallMatrices.cl
18+
${CLTEMPLATE_PATH}/sgemm_gcn_bigMatrices.cl
1819
${CLTEMPLATE_PATH}/sgemm_hawaiiSplit64_32.cl
1920
${CLTEMPLATE_PATH}/dtrsm_gpu192.cl
2021
)

src/library/blas/functor/hawaii_sgemmSplitKernel.cc

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,153 @@ cl_int clBlashawaiiSgemmSplitKernelFunctor::KernelsLaunch(cl_command_queue queue
683683

684684
std::size_t gs[2] = {GlobalX, GlobalY};
685685
cl_int error = 0;
686+
687+
//deals with square matrix sizes where K is a multiple of 16 for now
688+
if (args.lda == args.ldb)
689+
{
690+
if ((args.K % 16 == 0) && (args.lda >= 6144) && (args.ldb >= 6144))
691+
{
692+
if ((args.lda % 1024 == 0) && (args.ldb % 1024 == 0) && (args.transA == clblasNoTrans) && (args.transB == clblasTrans))
693+
{
694+
//handles special cases where a direct call to "sgemm_NT_96_96_16..." causes a perf drop due to cache misses/thrashing
695+
//this special case is: sgemm column major NT / sgemm row major TN; lda and ldb are big multiples of 1024 such as 4096 and 6144
696+
//K is bigger than a threshold: 1536 for lda=ldb=6144
697+
698+
//
699+
int K_block_size;
700+
if (args.lda == 6144)
701+
{
702+
K_block_size = 1536;
703+
}
704+
else
705+
{
706+
K_block_size = 128;
707+
}
708+
709+
if (args.M % 96 == 0 && args.N % 96 == 0)
710+
{
711+
if (VERB) printf(" ===> EXECUTE KERNEL 0 \n");
712+
if (args.K > K_block_size)
713+
{
714+
//split into many GEMM calls with K = K_block_size
715+
//there are at least 2 GEMM calls
716+
int num_of_gemm = ((args.K - 1) / K_block_size) + 1;
717+
718+
//call first GEMM
719+
unsigned int small_K = K_block_size;
720+
setKernelArg<int>(Kernel[0], 5, small_K);
721+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL, gs, m_variantSplit->ls, args.numEventsInWaitList, args.eventWaitList, NULL);
722+
723+
//call middle GEMMs
724+
unsigned beta_one = 1.0f;
725+
setKernelArg<int>(Kernel[0], 7, beta_one);
726+
for (int i = 1; i < num_of_gemm - 1; i++)
727+
{
728+
unsigned offa_i = args.lda * (args.K / num_of_gemm) * i + args.offA;
729+
unsigned offb_i = args.ldb * (args.K / num_of_gemm) * i + args.offB;
730+
setKernelArg<int>(Kernel[0], 11, offa_i);
731+
setKernelArg<int>(Kernel[0], 12, offb_i);
732+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
733+
}
734+
//call last GEMM
735+
//the last GEMM's K might be smaller than small_K
736+
unsigned int residue_K = args.K % small_K;
737+
if (residue_K == 0)
738+
residue_K = small_K;
739+
unsigned offa_i = args.lda * (args.K / num_of_gemm) * (num_of_gemm - 1) + args.offA;
740+
unsigned offb_i = args.ldb * (args.K / num_of_gemm) * (num_of_gemm - 1) + args.offB;
741+
setKernelArg<int>(Kernel[0], 5, residue_K);
742+
setKernelArg<int>(Kernel[0], 11, offa_i);
743+
setKernelArg<int>(Kernel[0], 12, offb_i);
744+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL, gs, m_variantSplit->ls, 0, NULL, args.events);
745+
return error;
746+
}
747+
}
748+
749+
if (args.M % 96 != 0 && args.N % 96 != 0 && args.M >= 96 && args.N >= 96)
750+
{
751+
if (VERB) printf(" ===> EXECUTE KERNEL 0, 1, 2, 3 \n");
752+
753+
if (args.K > K_block_size)
754+
{
755+
int num_of_gemm = ((args.K - 1) / K_block_size) + 1;
756+
757+
//first 4 GEMMs
758+
unsigned int small_K = K_block_size;
759+
setKernelArg<int>(Kernel[0], 5, small_K);
760+
761+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL, gs, m_variantSplit->ls, args.numEventsInWaitList, args.eventWaitList, NULL);
762+
763+
gs[0] = 16;
764+
error |= clEnqueueNDRangeKernel(queue, Kernel[1], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
765+
766+
gs[1] = 16;
767+
gs[0] = GlobalX;
768+
error |= clEnqueueNDRangeKernel(queue, Kernel[2], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
769+
770+
gs[0] = 16; gs[1] = 16;
771+
error |= clEnqueueNDRangeKernel(queue, Kernel[3], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
772+
773+
//middle GEMMs
774+
unsigned beta_one = 1.0f;
775+
setKernelArg<int>(Kernel[0], 7, beta_one);
776+
for (int i = 1; i < num_of_gemm - 1; i++)
777+
{
778+
unsigned offa_i = args.lda * (args.K / num_of_gemm) * i + args.offA;
779+
unsigned offb_i = args.ldb * (args.K / num_of_gemm) * i + args.offB;
780+
setKernelArg<int>(Kernel[0], 11, offa_i);
781+
setKernelArg<int>(Kernel[0], 12, offb_i);
782+
//gs[2] = {GlobalX, GlobalY};
783+
gs[0] = GlobalX;
784+
gs[1] = GlobalY;
785+
786+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
787+
788+
gs[0] = 16;
789+
error |= clEnqueueNDRangeKernel(queue, Kernel[1], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
790+
791+
gs[1] = 16;
792+
gs[0] = GlobalX;
793+
error |= clEnqueueNDRangeKernel(queue, Kernel[2], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
794+
795+
gs[0] = 16; gs[1] = 16;
796+
error |= clEnqueueNDRangeKernel(queue, Kernel[3], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
797+
}
798+
//last 4 GEMMs
799+
unsigned int residue_K = args.K % small_K;
800+
if (residue_K == 0)
801+
residue_K = small_K;
802+
unsigned offa_i = args.lda * (args.K / num_of_gemm) * (num_of_gemm - 1) + args.offA;
803+
unsigned offb_i = args.ldb * (args.K / num_of_gemm) * (num_of_gemm - 1) + args.offB;
804+
setKernelArg<int>(Kernel[0], 5, residue_K);
805+
setKernelArg<int>(Kernel[0], 11, offa_i);
806+
setKernelArg<int>(Kernel[0], 12, offb_i);
807+
808+
gs[0] = GlobalX;
809+
gs[1] = GlobalY;
810+
811+
error = clEnqueueNDRangeKernel(queue, Kernel[0], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
812+
813+
gs[0] = 16;
814+
error |= clEnqueueNDRangeKernel(queue, Kernel[1], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
815+
816+
gs[1] = 16;
817+
gs[0] = GlobalX;
818+
error |= clEnqueueNDRangeKernel(queue, Kernel[2], 2, NULL, gs, m_variantSplit->ls, 0, NULL, NULL);
819+
820+
gs[0] = 16; gs[1] = 16;
821+
error |= clEnqueueNDRangeKernel(queue, Kernel[3], 2, NULL, gs, m_variantSplit->ls, 0, NULL, args.events);
822+
823+
824+
return error;
825+
}
826+
}
827+
828+
829+
}
830+
}
831+
}
832+
686833

687834
if (args.M%96==0 && args.N%96==0)
688835
{

0 commit comments

Comments
 (0)