Skip to content

Commit 5137231

Browse files
author
Timmy
committed
Merge pull request #90 from TimmyLiu/develop
improve small sgemm column major TN on Hawaii
2 parents fdcf987 + a280c96 commit 5137231

File tree

4 files changed

+195
-23
lines changed

4 files changed

+195
-23
lines changed

src/library/blas/functor/hawaii.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,16 @@ clblasSgemmFunctor * FunctorSelectorHawaii::select_sgemm_specific(clblasSgemmFun
101101
//TODO: the logic below is complicated; Needs cleanup;
102102
clblasSgemmFunctor * functor;
103103
bool Not_TT = ((args.transA==clblasNoTrans && args.transB==clblasTrans ) || ( args.transA==clblasNoTrans && args.transB==clblasNoTrans ) || ( args.transA==clblasTrans && args.transB==clblasNoTrans ));
104-
bool SmallMatrices = args.M/6*args.N/6<150*150 || ((args.M%64!=0 && args.N%64!=0 && args.M<1900 &&args.N<1900 ) && (args.M%96!=0 && args.N%96!=0 && args.M<1900 &&args.N<1900 ));
104+
bool SmallMatrices = args.M/6*args.N/6<180*180 || ((args.M%64!=0 && args.N%64!=0 && args.M<1900 &&args.N<1900 ) && (args.M%96!=0 && args.N%96!=0 && args.M<1900 &&args.N<1900 ));
105105
bool SmallMatricesMod32= (SmallMatrices && (args.M%32==0&&args.N%32==0)) ;
106106
SmallMatricesMod32 = SmallMatricesMod32&&Not_TT&&args.K % 16 == 0;
107107
//SmallMatrices= false;
108108

109109
bool useSpliKernel=((args.M%96==0 && args.N%96==0) || !(args.M%64==0 && args.N%64==0&& args.M<4000 &&args.N<4000)) /*&&args.K%16==0*/;
110110
useSpliKernel=useSpliKernel&&Not_TT;
111111

112-
//the English translation of below is: if small matrix that is (not mod32) and (NT or NN) and K has to be mod 16
113-
if (SmallMatrices && (!SmallMatricesMod32) && (args.transA == clblasNoTrans) && (args.K%16 == 0))
112+
//the English translation of below is: if small matrix that is (not mod32) and (not_TT) and K has to be mod 16
113+
if (SmallMatrices && (!SmallMatricesMod32) && (Not_TT) && (args.K%16 == 0))
114114
{
115115
functor = clBlashawaiiSgemmBranchKernelFunctor::provide(args, "Hawaii");
116116
if (functor)

src/library/blas/functor/hawaii_sgemmBranchKernel.cc

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ static const Variant * select_variant_BranchKernel(clblasSgemmFunctor::Args & ar
123123

124124
// ===== sgemm NN ======
125125
// sgemm_NN_32_32_16_16x16_2x2__ALPHABETA_BRANCH
126-
const char* KName_NT = SGEMM_KERNEL_NAME(N, N, 32, 32, 16, 16, 16, 2, 2, __ALPHABETA, BRANCH);
126+
const char* KName_NN = SGEMM_KERNEL_NAME(N, N, 32, 32, 16, 16, 16, 2, 2, __ALPHABETA, BRANCH);
127127
const char* KBin_NN64;
128128
size_t KBin_NNSize64 = 0;
129129
#if BUILD_KERNEL_FROM_STRING
@@ -132,14 +132,14 @@ static const Variant * select_variant_BranchKernel(clblasSgemmFunctor::Args & ar
132132
#else
133133
if (!strcmp(DevName, "Hawaii"))
134134
{
135-
//KBin_NT64 = SGEMM_SRC_NAME_BIN(N, T, 16, __ALPHABETA, 64, HAWAII) ;
135+
136136
KBin_NN64 = sgemm_NN_32_32_16_16x16_2x2__ALPHABETA_BRANCH_64_bin_Hawaii;
137137
KBin_NNSize64 = sizeof(sgemm_NN_32_32_16_16x16_2x2__ALPHABETA_BRANCH_64_bin_Hawaii);
138138

139139
}
140140
#endif
141141
static const Variant variant = SGEMM_VARIANT_OBJ(N, N, 16, 16, 16, 2, 2, 64, __ALPHABETA,
142-
KName_NT,
142+
KName_NN,
143143
NULL,
144144
NULL,
145145
KBin_NN64,
@@ -188,21 +188,49 @@ static const Variant * select_variant_BranchKernel(clblasSgemmFunctor::Args & ar
188188

189189
return &variant;
190190
}
191-
else
191+
}
192+
else
193+
{
194+
if (args.transB == clblasNoTrans)
192195
{
193-
if (args.transB == clblasNoTrans)
194-
{
195196

196197
// ===== sgemm TN ======
197-
// currently not supported
198+
//sgemm_TN_32_32_16_16x16_2x2__ALPHABETA_BRANCH
199+
const char* KName_TN = SGEMM_KERNEL_NAME(T, N, 32, 32, 16, 16, 16, 2, 2, __ALPHABETA, BRANCH);
200+
201+
202+
const char* KBin_TN64;
203+
size_t KBin_TNSize64 = 0;
204+
205+
206+
#if BUILD_KERNEL_FROM_STRING
207+
//currently not supported
198208
return NULL;
199-
}
200-
}
209+
#else
210+
if (!strcmp(DevName, "Hawaii"))
211+
{
212+
KBin_TN64 = sgemm_TN_32_32_16_16x16_2x2__ALPHABETA_BRANCH_64_bin_Hawaii;
213+
KBin_TNSize64 = sizeof(sgemm_TN_32_32_16_16x16_2x2__ALPHABETA_BRANCH_64_bin_Hawaii);
201214

215+
}
216+
#endif
217+
// ===== SGEMM NT ======
218+
static const Variant variant = SGEMM_VARIANT_OBJ(T, N, 16, 16, 16, 2, 2, 64, __ALPHABETA,
219+
KName_TN,
220+
NULL,
221+
NULL,
222+
KBin_TN64,
223+
KBin_TNSize64);
224+
225+
return &variant;
226+
}
202227
return NULL;
203228
}
229+
230+
return NULL;
204231
}
205232

233+
206234
clBlashawaiiSgemmBranchKernelFunctor::clBlashawaiiSgemmBranchKernelFunctor(Args & args, const Variant * variant, cl_int & err)
207235
{
208236

src/library/blas/gens/clTemplates/sgemm_gcn_SmallMatrices.cl

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,3 +910,127 @@ __kernel void sgemm_TN_32_32_16_16x16_2x2__ALPHA( __global float const * restric
910910

911911
";
912912

913+
static const char * sgemm_TN_32_32_16_16x16_2x2__ALPHABETA_BRANCH = "
914+
915+
#define M2x2 \
916+
rA[0][0] = lA[offA + 0]; \
917+
rA[0][1] = lA[offA + 16]; \
918+
rB[0][0] = lB[offB + 0]; \
919+
rB[0][1] = lB[offB + 16]; \
920+
offA += 33; \
921+
offB += 33; \
922+
rC[0][0]=mad(rA[0][0],rB[0][0],rC[0][0]); \
923+
rC[1][0]=mad(rA[0][1],rB[0][0],rC[1][0]); \
924+
rC[0][1]=mad(rA[0][0],rB[0][1],rC[0][1]); \
925+
rC[1][1]=mad(rA[0][1],rB[0][1],rC[1][1]); \
926+
mem_fence(CLK_LOCAL_MEM_FENCE);
927+
928+
__attribute__((reqd_work_group_size(16,16,1)))
929+
930+
__kernel void sgemm_TN_32_32_16_16x16_2x2__ALPHABETA_BRANCH( __global float const * restrict A,
931+
__global float const * restrict B,
932+
__global float * C,
933+
uint const M,
934+
uint const N,
935+
uint const K,
936+
float const alpha,
937+
float const beta,
938+
uint lda,
939+
uint ldb,
940+
uint ldc,
941+
uint offsetA,
942+
uint offsetB,
943+
uint offsetC)
944+
{
945+
float rC[2][2] = {(float)0};
946+
float rA[1][2];
947+
float rB[1][2];
948+
949+
950+
A += offsetA;
951+
B += offsetB;
952+
C+=offsetC;
953+
954+
__local float lA[528];//16*32+16
955+
__local float lB[528];
956+
957+
uint gidx = get_group_id(0);
958+
uint gidy = get_group_id(1);
959+
uint idx = get_local_id(0);
960+
uint idy = get_local_id(1);
961+
962+
int CurrentOffSetA = gidx*32+ idy;
963+
int CurrentOffSetB = gidy*32+ idy;
964+
965+
A += (gidx*32+idy)*lda + idx;
966+
B += (gidy*32+idy)*ldb + idx;
967+
968+
969+
uint block_k = K >> 4;
970+
do
971+
{
972+
__local float* plA = lA + idx*33+idy;
973+
__local float* plB = lB + idx*33+idy;
974+
barrier(CLK_LOCAL_MEM_FENCE);
975+
976+
plB[0] = CurrentOffSetB>=N?0.0:B[0];
977+
plB[16] = CurrentOffSetB+16>=N?0.0:B[16*ldb];
978+
979+
plA[0] = CurrentOffSetA>=M?0.0:A[0];
980+
plA[16] = CurrentOffSetA+16>=M?0.0:A[16*lda];
981+
982+
983+
barrier(CLK_LOCAL_MEM_FENCE);
984+
uint offA = idx;
985+
uint offB = idy;
986+
987+
988+
M2x2
989+
M2x2
990+
M2x2
991+
M2x2
992+
M2x2
993+
M2x2
994+
M2x2
995+
M2x2
996+
M2x2
997+
M2x2
998+
M2x2
999+
M2x2
1000+
M2x2
1001+
M2x2
1002+
M2x2
1003+
M2x2
1004+
1005+
A += 16;
1006+
B += 16;
1007+
} while (--block_k > 0);
1008+
1009+
1010+
int offset_x = gidx*32+idx;
1011+
int offset_y = gidy*32+ idy;
1012+
1013+
if(offset_x>=M || offset_y>=N )
1014+
return;
1015+
1016+
C+=offset_x+offset_y*ldc;
1017+
1018+
int i = 0;
1019+
do
1020+
{
1021+
C[0 ] = mad(alpha, rC[i][0], beta*C[0]);
1022+
if(offset_y+16<N)
1023+
C[16*ldc] = mad(alpha, rC[i][1], beta*C[16*ldc]);
1024+
1025+
C+=16;
1026+
offset_x+=16;
1027+
if(offset_x>=M )
1028+
return;
1029+
1030+
1031+
}
1032+
while (++i < 2);
1033+
}
1034+
1035+
";
1036+

0 commit comments

Comments
 (0)