Skip to content

Commit 39b324d

Browse files
author
Timmy
committed
improve big sgemm column NN perf. improve small sgemm NN perf.
1 parent fda48a7 commit 39b324d

File tree

4 files changed

+153
-5
lines changed

4 files changed

+153
-5
lines changed

src/library/blas/functor/hawaii.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,16 @@ clblasSgemmFunctor * FunctorSelectorHawaii::select_sgemm_specific(clblasSgemmFun
101101
//TODO: the logic below is complicated and needs cleanup.
102102
clblasSgemmFunctor * functor;
103103
bool Not_TT = ((args.transA==clblasNoTrans && args.transB==clblasTrans ) || ( args.transA==clblasNoTrans && args.transB==clblasNoTrans ) || ( args.transA==clblasTrans && args.transB==clblasNoTrans ));
104-
bool SmallMatrices = args.M/6*args.N/6<200*200 || ((args.M%64!=0 && args.N%64!=0 && args.M<1900 &&args.N<1900 ) && (args.M%96!=0 && args.N%96!=0 && args.M<1900 &&args.N<1900 ));
104+
bool SmallMatrices = args.M/6*args.N/6<150*150 || ((args.M%64!=0 && args.N%64!=0 && args.M<1900 &&args.N<1900 ) && (args.M%96!=0 && args.N%96!=0 && args.M<1900 &&args.N<1900 ));
105105
bool SmallMatricesMod32= (SmallMatrices && (args.M%32==0&&args.N%32==0)) ;
106106
SmallMatricesMod32 = SmallMatricesMod32&&Not_TT&&args.K % 16 == 0;
107107
//SmallMatrices= false;
108108

109109
bool useSpliKernel=((args.M%96==0 && args.N%96==0) || !(args.M%64==0 && args.N%64==0&& args.M<4000 &&args.N<4000)) /*&&args.K%16==0*/;
110110
useSpliKernel=useSpliKernel&&Not_TT;
111111

112-
//the English translation of below is: if small matrix that is not mod32 and NT and K has to be mod 16
113-
if (SmallMatrices && (!SmallMatricesMod32) && (args.transA == clblasNoTrans && args.transB == clblasTrans) && (args.K%16 == 0))
112+
//In plain English: take this path if the matrices are small, did not qualify for the mod-32 path, A is not transposed (NT or NN), and K is a multiple of 16
113+
if (SmallMatrices && (!SmallMatricesMod32) && (args.transA == clblasNoTrans) && (args.K%16 == 0))
114114
{
115115
functor = clBlashawaiiSgemmBranchKernelFunctor::provide(args, "Hawaii");
116116
if (functor)

src/library/blas/functor/hawaii_sgemmBranchKernel.cc

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,30 @@ static const Variant * select_variant_BranchKernel(clblasSgemmFunctor::Args & ar
122122
{
123123

124124
// ===== sgemm NN ======
125-
// currently not supported
125+
// sgemm_NN_32_32_16_16x16_2x2__ALPHABETA_BRANCH
126+
const char* KName_NT = SGEMM_KERNEL_NAME(N, N, 32, 32, 16, 16, 16, 2, 2, __ALPHABETA, BRANCH);
127+
const char* KBin_NN64;
128+
size_t KBin_NNSize64 = 0;
129+
#if BUILD_KERNEL_FROM_STRING
130+
//building this NN branch kernel from a source string is currently not supported
126131
return NULL;
132+
#else
133+
if (!strcmp(DevName, "Hawaii"))
134+
{
135+
//KBin_NT64 = SGEMM_SRC_NAME_BIN(N, T, 16, __ALPHABETA, 64, HAWAII) ;
136+
KBin_NN64 = sgemm_NN_32_32_16_16x16_2x2__ALPHABETA_BRANCH_64_bin_Hawaii;
137+
KBin_NNSize64 = sizeof(sgemm_NN_32_32_16_16x16_2x2__ALPHABETA_BRANCH_64_bin_Hawaii);
138+
139+
}
140+
#endif
141+
static const Variant variant = SGEMM_VARIANT_OBJ(N, N, 16, 16, 16, 2, 2, 64, __ALPHABETA,
142+
KName_NT,
143+
NULL,
144+
NULL,
145+
KBin_NN64,
146+
KBin_NNSize64);
147+
148+
return &variant;
127149
}
128150
if (args.transB == clblasTrans)
129151
{

src/library/blas/gens/clTemplates/sgemm_gcn_SmallMatrices.cl

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,131 @@ __kernel void sgemm_NN_32_32_16_16x16_2x2__ALPHA( __global float const * restric
567567

568568
";
569569

570+
static const char * sgemm_NN_32_32_16_16x16_2x2__ALPHABETA_BRANCH = "
571+
572+
#define M2x2 \
573+
rA[0][0] = lA[offA + 0]; \
574+
rA[0][1] = lA[offA + 16]; \
575+
rB[0][0] = lB[offB + 0]; \
576+
rB[0][1] = lB[offB + 16]; \
577+
offA += 33; \
578+
offB += 33; \
579+
rC[0][0]=mad(rA[0][0],rB[0][0],rC[0][0]); \
580+
rC[1][0]=mad(rA[0][1],rB[0][0],rC[1][0]); \
581+
rC[0][1]=mad(rA[0][0],rB[0][1],rC[0][1]); \
582+
rC[1][1]=mad(rA[0][1],rB[0][1],rC[1][1]); \
583+
mem_fence(CLK_LOCAL_MEM_FENCE);
584+
585+
__attribute__((reqd_work_group_size(16,16,1)))
586+
587+
__kernel void sgemm_NN_32_32_16_16x16_2x2__ALPHABETA_BRANCH( __global float const * restrict A,
588+
__global float const * restrict B,
589+
__global float * C,
590+
uint const M,
591+
uint const N,
592+
uint const K,
593+
float const alpha,
594+
float const beta,
595+
uint lda,
596+
uint ldb,
597+
uint ldc,
598+
uint offsetA,
599+
uint offsetB,
600+
uint offsetC)
601+
{
602+
float rC[2][2] = {(float)0};
603+
float rA[1][2];
604+
float rB[1][2];
605+
606+
607+
608+
A += offsetA;
609+
B += offsetB;
610+
C+=offsetC;
611+
612+
__local float lA[528];//16*32+16
613+
__local float lB[528];
614+
615+
uint gidx = get_group_id(0);
616+
uint gidy = get_group_id(1);
617+
uint idx = get_local_id(0);
618+
uint idy = get_local_id(1);
619+
620+
int CurrentOffSetA = gidx*32+ idx;
621+
int CurrentOffSetB = gidy*32+ idy;
622+
623+
A += gidx*32+ idx + idy*lda;
624+
B += gidy*32*ldb+ idx + idy*ldb;
625+
626+
627+
uint block_k = K >> 4;
628+
do
629+
{
630+
__local float* plA = lA + idy*33+idx;
631+
__local float* plB = lB + idx*33+idy;
632+
barrier(CLK_LOCAL_MEM_FENCE);
633+
634+
plB[0] = CurrentOffSetB>=N?0.0:B[0];
635+
plB[16] = CurrentOffSetB+16>=N?0.0:B[16*ldb];
636+
637+
plA[0] = CurrentOffSetA>=M?0.0:A[0];
638+
plA[16] = CurrentOffSetA+16>=M?0.0:A[16];
639+
640+
641+
barrier(CLK_LOCAL_MEM_FENCE);
642+
uint offA = idx;
643+
uint offB = idy;
644+
645+
M2x2
646+
M2x2
647+
M2x2
648+
M2x2
649+
M2x2
650+
M2x2
651+
M2x2
652+
M2x2
653+
M2x2
654+
M2x2
655+
M2x2
656+
M2x2
657+
M2x2
658+
M2x2
659+
M2x2
660+
M2x2
661+
662+
A += lda<<4;
663+
B += 16;
664+
//}
665+
} while (--block_k > 0);
666+
667+
int offset_x = gidx*32+idx;
668+
int offset_y = gidy*32+ idy;
669+
if(offset_x>=M || offset_y>=N )
670+
return;
671+
672+
C+=offset_x+offset_y*ldc;
673+
674+
675+
int i = 0;
676+
do
677+
{
678+
C[0 ] = mad(alpha, rC[i][0], beta*C[0]);
679+
if(offset_y+16<N)
680+
C[16*ldc] = mad(alpha, rC[i][1], beta*C[16*ldc]);
681+
682+
C+=16;
683+
offset_x+=16;
684+
if(offset_x>=M )
685+
return;
686+
687+
688+
}
689+
while (++i < 2);
690+
691+
}
692+
693+
";
694+
570695
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
571696
static const char * sgemm_TN_32_32_16_16x16_2x2__ALPHABETA = "
572697

src/library/blas/gens/clTemplates/sgemm_hawaiiSplitKernel.cl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2546,7 +2546,7 @@ static const char * sgemm_NN_16_SPLIT__ALPHABETA = "
25462546
rC[3][5]=mad(rA[0][3],rB[0][5],rC[3][5]); \
25472547
rC[4][5]=mad(rA[0][4],rB[0][5],rC[4][5]); \
25482548
rC[5][5]=mad(rA[0][5],rB[0][5],rC[5][5]); \
2549-
barrier(CLK_LOCAL_MEM_FENCE);
2549+
mem_fence(CLK_LOCAL_MEM_FENCE);
25502550

25512551
__attribute__((reqd_work_group_size(16,16,1)))
25522552
__kernel void sgemm_NN_96_96_16_16x16_6x6__ALPHABETA_SPLIT_MAIN( __global float const * restrict A,
@@ -2591,6 +2591,7 @@ __kernel void sgemm_NN_96_96_16_16x16_6x6__ALPHABETA_SPLIT_MAIN( __global float
25912591
{
25922592
__local float* plA = lA + idy*97+idx;
25932593
__local float* plB = lB + idx*97+idy;
2594+
barrier(CLK_LOCAL_MEM_FENCE);
25942595
plB[0] = B[0];
25952596
plB[16] = B[16*ldb];
25962597
plB[32] = B[32*ldb];

0 commit comments

Comments
 (0)