Skip to content

Commit 9c66a77

Browse files
pavankyshehzan10
authored andcommitted
Fixing issues for when Beta == 0 in sgemm special cases
1 parent 627c654 commit 9c66a77

6 files changed

+162
-105
lines changed

src/library/blas/AutoGemm/UserGemmKernelSources/sgemm_Col_NN_B1_MX032_NX032_KX16_BRANCH_src.cpp

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -57,46 +57,46 @@ __kernel void sgemm_Col_NN_B1_MX032_NX032_KX16_BRANCH (
5757
float rC[2][2] = { {(float)0} };
5858
float rA[1][2];
5959
float rB[1][2];
60-
6160

62-
61+
62+
6363
A += offsetA;
6464
B += offsetB;
6565
C+=offsetC;
66-
66+
6767
__local float lA[528];//16*32+16
6868
__local float lB[528];
69-
69+
7070
uint gidx = get_group_id(0);
7171
uint gidy = get_group_id(1);
7272
uint idx = get_local_id(0);
7373
uint idy = get_local_id(1);
74-
74+
7575
int CurrentOffSetA = gidx*32+ idx;
7676
int CurrentOffSetB = gidy*32+ idy;
7777

7878
A += gidx*32+ idx + idy*lda;
7979
B += gidy*32*ldb+ idx + idy*ldb;
80-
81-
80+
81+
8282
uint block_k = K >> 4;
83-
do
83+
do
8484
{
8585
__local float* plA = lA + idy*33+idx;
8686
__local float* plB = lB + idx*33+idy;
8787
barrier(CLK_LOCAL_MEM_FENCE);
88-
88+
8989
plB[0] = CurrentOffSetB>=N?0.0:B[0];
9090
plB[16] = CurrentOffSetB+16>=N?0.0:B[16*ldb];
91-
91+
9292
plA[0] = CurrentOffSetA>=M?0.0:A[0];
9393
plA[16] = CurrentOffSetA+16>=M?0.0:A[16];
9494

95-
95+
9696
barrier(CLK_LOCAL_MEM_FENCE);
9797
uint offA = idx;
9898
uint offB = idy;
99-
99+
100100
M2x2
101101
M2x2
102102
M2x2
@@ -123,26 +123,36 @@ __kernel void sgemm_Col_NN_B1_MX032_NX032_KX16_BRANCH (
123123
int offset_y = gidy*32+ idy;
124124
if(offset_x>=M || offset_y>=N )
125125
return;
126-
126+
127127
C+=offset_x+offset_y*ldc;
128-
129-
128+
129+
130130
int i = 0;
131-
do
132-
{
133-
C[0 ] = mad(alpha, rC[i][0], beta*C[0]);
134-
if(offset_y+16<N)
135-
C[16*ldc] = mad(alpha, rC[i][1], beta*C[16*ldc]);
136-
137-
C+=16;
138-
offset_x+=16;
139-
if(offset_x>=M )
140-
return;
141-
142-
143-
}
144-
while (++i < 2);
145-
131+
if (beta != 0) {
132+
do
133+
{
134+
C[0 ] = mad(alpha, rC[i][0], beta*C[0]);
135+
if(offset_y+16<N)
136+
C[16*ldc] = mad(alpha, rC[i][1], beta*C[16*ldc]);
137+
C+=16;
138+
offset_x+=16;
139+
if(offset_x>=M )
140+
return;
141+
}
142+
while (++i < 2);
143+
} else {
144+
do
145+
{
146+
C[0 ] = alpha * rC[i][0];
147+
if(offset_y+16<N)
148+
C[16*ldc] = alpha * rC[i][1];
149+
C+=16;
150+
offset_x+=16;
151+
if(offset_x>=M )
152+
return;
153+
}
154+
while (++i < 2);
155+
}
146156
}
147157

148158
);

src/library/blas/AutoGemm/UserGemmKernelSources/sgemm_Col_NT_B1_MX032_NX032_KX16_BRANCH_src.cpp

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -57,41 +57,41 @@ __kernel void sgemm_Col_NT_B1_MX032_NX032_KX16_BRANCH (
5757
float rC[2][2] = { {(float)0} };
5858
float rA[1][2];
5959
float rB[1][2];
60-
61-
60+
61+
6262
A += offsetA;
6363
B += offsetB;
6464
C+=offsetC;
65-
65+
6666
__local float lA[528];//16*32+16
6767
__local float lB[528];
68-
68+
6969
uint gidx = get_group_id(0);
7070
uint gidy = get_group_id(1);
7171
uint idx = get_local_id(0);
7272
uint idy = get_local_id(1);
73-
73+
7474
int CurrentOffSetA = gidx*32+ idx;
7575
int CurrentOffSetB = gidy*32+ idx;
76-
76+
7777
A += gidx*32+ idx + idy*lda;
7878
B += gidy*32+ idx + idy*ldb;
79-
80-
79+
80+
8181
uint block_k = K >> 4;
82-
do
82+
do
8383
{
8484
__local float* plA = lA + idy*33+idx;
8585
__local float* plB = lB + idy*33+idx;
8686
barrier(CLK_LOCAL_MEM_FENCE);
87-
87+
8888
plB[0] = CurrentOffSetB>=N?0.0:B[0];
8989
plB[16] = CurrentOffSetB+16>=N?0.0:B[16];
90-
90+
9191
plA[0] = CurrentOffSetA>=M?0.0:A[0];
9292
plA[16] = CurrentOffSetA+16>=M?0.0:A[16];
9393

94-
94+
9595
barrier(CLK_LOCAL_MEM_FENCE);
9696
uint offA = idx;
9797
uint offB = idy;
@@ -126,23 +126,35 @@ __kernel void sgemm_Col_NT_B1_MX032_NX032_KX16_BRANCH (
126126
return;
127127

128128
C+=offset_x+offset_y*ldc;
129-
130-
int i = 0;
131-
do
132-
{
133-
C[0 ] = mad(alpha, rC[i][0], beta*C[0]);
134-
if(offset_y+16<N)
135-
C[16*ldc] = mad(alpha, rC[i][1], beta*C[16*ldc]);
136-
137-
C+=16;
138-
offset_x+=16;
139-
if(offset_x>=M )
140-
return;
141-
142-
143-
}
144-
while (++i < 2);
145129

130+
int i = 0;
131+
if (beta !=0 ) {
132+
do
133+
{
134+
C[0 ] = mad(alpha, rC[i][0], beta*C[0]);
135+
if(offset_y+16<N)
136+
C[16*ldc] = mad(alpha, rC[i][1], beta*C[16*ldc]);
137+
138+
C+=16;
139+
offset_x+=16;
140+
if(offset_x>=M )
141+
return;
142+
}
143+
while (++i < 2);
144+
} else {
145+
do
146+
{
147+
C[0 ] = alpha * rC[i][0];
148+
if(offset_y+16<N)
149+
C[16*ldc] = alpha * rC[i][1];
150+
151+
C+=16;
152+
offset_x+=16;
153+
if(offset_x>=M )
154+
return;
155+
}
156+
while (++i < 2);
157+
}
146158
}
147159
);
148160
#endif

src/library/blas/AutoGemm/UserGemmKernelSources/sgemm_Col_NT_B1_MX032_NX032_KX16_SINGLE_src.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -133,25 +133,26 @@ __kernel void sgemm_Col_NT_B1_MX032_NX032_KX16_SINGLE (
133133
int offset_x = gidx * 64 + idx;
134134
int offset_y = gidy * 64 + idy;
135135

136-
//if(offset_x>=M || offset_y>=N )
137-
// return;
138-
139136
C += offset_x + offset_y*ldc;
140137

141138
int i = 0;
142-
do
143-
{
144-
C[0] = mad(alpha, rC[i][0], beta*C[0]);
145-
C[16 * ldc] = mad(alpha, rC[i][1], beta*C[16 * ldc]);
146-
147-
148-
C += 16;
149-
offset_x += 16;
150-
//if(offset_x>=M )
151-
// return;
152-
153-
154-
} while (++i < 2);
139+
if (beta != 0) {
140+
do
141+
{
142+
C[0] = mad(alpha, rC[i][0], beta*C[0]);
143+
C[16 * ldc] = mad(alpha, rC[i][1], beta*C[16 * ldc]);
144+
C += 16;
145+
offset_x += 16;
146+
} while (++i < 2);
147+
} else {
148+
do
149+
{
150+
C[0] = alpha * rC[i][0];
151+
C[16 * ldc] = alpha * rC[i][1];
152+
C += 16;
153+
offset_x += 16;
154+
} while (++i < 2);
155+
}
155156

156157
}
157158
);

src/library/blas/AutoGemm/UserGemmKernelSources/sgemm_Col_NT_B1_MX032_NX064_KX16_ROW_src.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,9 @@ __kernel void sgemm_Col_NT_B1_MX032_NX064_KX16_ROW (
145145
C += offset_x + offset_y*ldc;
146146

147147
int i = 0;
148-
do
149-
{
148+
if (beta != 0) {
149+
do
150+
{
150151
C[0] = mad(alpha, rC[i][0], beta*C[0]);
151152
C[16 * ldc] = mad(alpha, rC[i][1], beta*C[16 * ldc]);
152153
C[32 * ldc] = mad(alpha, rC[i][2], beta*C[32 * ldc]);
@@ -155,7 +156,20 @@ __kernel void sgemm_Col_NT_B1_MX032_NX064_KX16_ROW (
155156
offset_x += 16;
156157
//if(offset_x>=M )
157158
// return;
158-
} while (++i < 2);
159+
} while (++i < 2);
160+
} else {
161+
do
162+
{
163+
C[0] = alpha * rC[i][0];
164+
C[16 * ldc] = alpha * rC[i][1];
165+
C[32 * ldc] = alpha * rC[i][2];
166+
C[48 * ldc] = alpha * rC[i][3];
167+
C += 16;
168+
offset_x += 16;
169+
//if(offset_x>=M )
170+
// return;
171+
} while (++i < 2);
172+
}
159173
}
160174
);
161175
#endif

src/library/blas/AutoGemm/UserGemmKernelSources/sgemm_Col_NT_B1_MX064_NX032_KX16_COL_src.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,15 +143,21 @@ __kernel void sgemm_Col_NT_B1_MX064_NX032_KX16_COLUMN (
143143
C += offset_x + offset_y*ldc;
144144

145145
int i = 0;
146-
do
147-
{
146+
if (beta != 0) {
147+
do
148+
{
148149
C[0] = mad(alpha, rC[i][0], beta*C[0]);
149150
C[16 * ldc] = mad(alpha, rC[i][1], beta*C[16 * ldc]);
150-
151151
C += 16;
152-
153-
} while (++i < 4);
154-
152+
} while (++i < 4);
153+
} else {
154+
do
155+
{
156+
C[0] = alpha * rC[i][0];
157+
C[16 * ldc] = alpha * rC[i][1];
158+
C += 16;
159+
} while (++i < 4);
160+
}
155161
}
156162
);
157163
#endif

0 commit comments

Comments
 (0)