Skip to content

Commit ae53114

Browse files
committed
Merge pull request #221 from arrayfire/arrayfire-release-test
Bug fixes to AutoGEMM and DTRSM, DTRTRI
2 parents 3ec45fd + d32081a commit ae53114

13 files changed

+280
-216
lines changed

src/library/blas/AutoGemm/UserGemmKernelSources/sgemm_Col_NN_B0_MX096_NX096_KX16_src.cpp

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -161,47 +161,47 @@ __kernel void sgemm_Col_NN_B0_MX096_NX096_KX16 (
161161
C+= gidy*96*ldc;
162162
C+= idy*ldc;
163163

164-
C[0*ldc] = alpha*rC[0][0] + beta*C[0*ldc];
165-
C[16*ldc] = alpha*rC[0][1] + beta*C[16*ldc];
166-
C[32*ldc] = alpha*rC[0][2] + beta*C[32*ldc];
167-
C[48*ldc] = alpha*rC[0][3] + beta*C[48*ldc];
168-
C[64*ldc] = alpha*rC[0][4] + beta*C[64*ldc];
169-
C[80*ldc] = alpha*rC[0][5] + beta*C[80*ldc];
164+
C[0 *ldc] = alpha*rC[0][0];
165+
C[16*ldc] = alpha*rC[0][1];
166+
C[32*ldc] = alpha*rC[0][2];
167+
C[48*ldc] = alpha*rC[0][3];
168+
C[64*ldc] = alpha*rC[0][4];
169+
C[80*ldc] = alpha*rC[0][5];
170170
C+=16;
171-
C[0*ldc] = alpha*rC[1][0] + beta*C[0*ldc];
172-
C[16*ldc] = alpha*rC[1][1] + beta*C[16*ldc];
173-
C[32*ldc] = alpha*rC[1][2] + beta*C[32*ldc];
174-
C[48*ldc] = alpha*rC[1][3] + beta*C[48*ldc];
175-
C[64*ldc] = alpha*rC[1][4] + beta*C[64*ldc];
176-
C[80*ldc] = alpha*rC[1][5] + beta*C[80*ldc];
171+
C[0 *ldc] = alpha*rC[1][0];
172+
C[16*ldc] = alpha*rC[1][1];
173+
C[32*ldc] = alpha*rC[1][2];
174+
C[48*ldc] = alpha*rC[1][3];
175+
C[64*ldc] = alpha*rC[1][4];
176+
C[80*ldc] = alpha*rC[1][5];
177177
C+=16;
178-
C[0*ldc] = alpha*rC[2][0] + beta*C[0*ldc];
179-
C[16*ldc] = alpha*rC[2][1] + beta*C[16*ldc];
180-
C[32*ldc] = alpha*rC[2][2] + beta*C[32*ldc];
181-
C[48*ldc] = alpha*rC[2][3] + beta*C[48*ldc];
182-
C[64*ldc] = alpha*rC[2][4] + beta*C[64*ldc];
183-
C[80*ldc] = alpha*rC[2][5] + beta*C[80*ldc];
178+
C[0 *ldc] = alpha*rC[2][0];
179+
C[16*ldc] = alpha*rC[2][1];
180+
C[32*ldc] = alpha*rC[2][2];
181+
C[48*ldc] = alpha*rC[2][3];
182+
C[64*ldc] = alpha*rC[2][4];
183+
C[80*ldc] = alpha*rC[2][5];
184184
C+=16;
185-
C[0*ldc] = alpha*rC[3][0] + beta*C[0*ldc];
186-
C[16*ldc] = alpha*rC[3][1] + beta*C[16*ldc];
187-
C[32*ldc] = alpha*rC[3][2] + beta*C[32*ldc];
188-
C[48*ldc] = alpha*rC[3][3] + beta*C[48*ldc];
189-
C[64*ldc] = alpha*rC[3][4] + beta*C[64*ldc];
190-
C[80*ldc] = alpha*rC[3][5] + beta*C[80*ldc];
185+
C[0 *ldc] = alpha*rC[3][0];
186+
C[16*ldc] = alpha*rC[3][1];
187+
C[32*ldc] = alpha*rC[3][2];
188+
C[48*ldc] = alpha*rC[3][3];
189+
C[64*ldc] = alpha*rC[3][4];
190+
C[80*ldc] = alpha*rC[3][5];
191191
C+=16;
192-
C[0*ldc] = alpha*rC[4][0] + beta*C[0*ldc];
193-
C[16*ldc] = alpha*rC[4][1] + beta*C[16*ldc];
194-
C[32*ldc] = alpha*rC[4][2] + beta*C[32*ldc];
195-
C[48*ldc] = alpha*rC[4][3] + beta*C[48*ldc];
196-
C[64*ldc] = alpha*rC[4][4] + beta*C[64*ldc];
197-
C[80*ldc] = alpha*rC[4][5] + beta*C[80*ldc];
192+
C[0 *ldc] = alpha*rC[4][0];
193+
C[16*ldc] = alpha*rC[4][1];
194+
C[32*ldc] = alpha*rC[4][2];
195+
C[48*ldc] = alpha*rC[4][3];
196+
C[64*ldc] = alpha*rC[4][4];
197+
C[80*ldc] = alpha*rC[4][5];
198198
C+=16;
199-
C[0*ldc] = alpha*rC[5][0] + beta*C[0*ldc];
200-
C[16*ldc] = alpha*rC[5][1] + beta*C[16*ldc];
201-
C[32*ldc] = alpha*rC[5][2] + beta*C[32*ldc];
202-
C[48*ldc] = alpha*rC[5][3] + beta*C[48*ldc];
203-
C[64*ldc] = alpha*rC[5][4] + beta*C[64*ldc];
204-
C[80*ldc] = alpha*rC[5][5] + beta*C[80*ldc];
199+
C[0 *ldc] = alpha*rC[5][0];
200+
C[16*ldc] = alpha*rC[5][1];
201+
C[32*ldc] = alpha*rC[5][2];
202+
C[48*ldc] = alpha*rC[5][3];
203+
C[64*ldc] = alpha*rC[5][4];
204+
C[80*ldc] = alpha*rC[5][5];
205205

206206
}
207207
);

src/library/blas/AutoGemm/UserGemmKernelSources/sgemm_Col_NN_B1_MX032_NX032_KX16_BRANCH_src.cpp

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -57,46 +57,46 @@ __kernel void sgemm_Col_NN_B1_MX032_NX032_KX16_BRANCH (
5757
float rC[2][2] = { {(float)0} };
5858
float rA[1][2];
5959
float rB[1][2];
60-
6160

62-
61+
62+
6363
A += offsetA;
6464
B += offsetB;
6565
C+=offsetC;
66-
66+
6767
__local float lA[528];//16*32+16
6868
__local float lB[528];
69-
69+
7070
uint gidx = get_group_id(0);
7171
uint gidy = get_group_id(1);
7272
uint idx = get_local_id(0);
7373
uint idy = get_local_id(1);
74-
74+
7575
int CurrentOffSetA = gidx*32+ idx;
7676
int CurrentOffSetB = gidy*32+ idy;
7777

7878
A += gidx*32+ idx + idy*lda;
7979
B += gidy*32*ldb+ idx + idy*ldb;
80-
81-
80+
81+
8282
uint block_k = K >> 4;
83-
do
83+
do
8484
{
8585
__local float* plA = lA + idy*33+idx;
8686
__local float* plB = lB + idx*33+idy;
8787
barrier(CLK_LOCAL_MEM_FENCE);
88-
88+
8989
plB[0] = CurrentOffSetB>=N?0.0:B[0];
9090
plB[16] = CurrentOffSetB+16>=N?0.0:B[16*ldb];
91-
91+
9292
plA[0] = CurrentOffSetA>=M?0.0:A[0];
9393
plA[16] = CurrentOffSetA+16>=M?0.0:A[16];
9494

95-
95+
9696
barrier(CLK_LOCAL_MEM_FENCE);
9797
uint offA = idx;
9898
uint offB = idy;
99-
99+
100100
M2x2
101101
M2x2
102102
M2x2
@@ -123,26 +123,36 @@ __kernel void sgemm_Col_NN_B1_MX032_NX032_KX16_BRANCH (
123123
int offset_y = gidy*32+ idy;
124124
if(offset_x>=M || offset_y>=N )
125125
return;
126-
126+
127127
C+=offset_x+offset_y*ldc;
128-
129-
128+
129+
130130
int i = 0;
131-
do
132-
{
133-
C[0 ] = mad(alpha, rC[i][0], beta*C[0]);
134-
if(offset_y+16<N)
135-
C[16*ldc] = mad(alpha, rC[i][1], beta*C[16*ldc]);
136-
137-
C+=16;
138-
offset_x+=16;
139-
if(offset_x>=M )
140-
return;
141-
142-
143-
}
144-
while (++i < 2);
145-
131+
if (beta != 0) {
132+
do
133+
{
134+
C[0 ] = mad(alpha, rC[i][0], beta*C[0]);
135+
if(offset_y+16<N)
136+
C[16*ldc] = mad(alpha, rC[i][1], beta*C[16*ldc]);
137+
C+=16;
138+
offset_x+=16;
139+
if(offset_x>=M )
140+
return;
141+
}
142+
while (++i < 2);
143+
} else {
144+
do
145+
{
146+
C[0 ] = alpha * rC[i][0];
147+
if(offset_y+16<N)
148+
C[16*ldc] = alpha * rC[i][1];
149+
C+=16;
150+
offset_x+=16;
151+
if(offset_x>=M )
152+
return;
153+
}
154+
while (++i < 2);
155+
}
146156
}
147157

148158
);

src/library/blas/AutoGemm/UserGemmKernelSources/sgemm_Col_NT_B0_MX096_NX096_KX16_src.cpp

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -163,47 +163,47 @@ __kernel void sgemm_Col_NT_B0_MX096_NX096_KX16 (
163163
C+= gidy*96*ldc;
164164
C+= idy*ldc;
165165

166-
C[0*ldc] = alpha*rC[0][0] + beta*C[0*ldc];
167-
C[16*ldc] = alpha*rC[0][1] + beta*C[16*ldc];
168-
C[32*ldc] = alpha*rC[0][2] + beta*C[32*ldc];
169-
C[48*ldc] = alpha*rC[0][3] + beta*C[48*ldc];
170-
C[64*ldc] = alpha*rC[0][4] + beta*C[64*ldc];
171-
C[80*ldc] = alpha*rC[0][5] + beta*C[80*ldc];
166+
C[0*ldc] = alpha*rC[0][0];
167+
C[16*ldc] = alpha*rC[0][1];
168+
C[32*ldc] = alpha*rC[0][2];
169+
C[48*ldc] = alpha*rC[0][3];
170+
C[64*ldc] = alpha*rC[0][4];
171+
C[80*ldc] = alpha*rC[0][5];
172172
C+=16;
173-
C[0*ldc] = alpha*rC[1][0] + beta*C[0*ldc];
174-
C[16*ldc] = alpha*rC[1][1] + beta*C[16*ldc];
175-
C[32*ldc] = alpha*rC[1][2] + beta*C[32*ldc];
176-
C[48*ldc] = alpha*rC[1][3] + beta*C[48*ldc];
177-
C[64*ldc] = alpha*rC[1][4] + beta*C[64*ldc];
178-
C[80*ldc] = alpha*rC[1][5] + beta*C[80*ldc];
173+
C[0*ldc] = alpha*rC[1][0];
174+
C[16*ldc] = alpha*rC[1][1];
175+
C[32*ldc] = alpha*rC[1][2];
176+
C[48*ldc] = alpha*rC[1][3];
177+
C[64*ldc] = alpha*rC[1][4];
178+
C[80*ldc] = alpha*rC[1][5];
179179
C+=16;
180-
C[0*ldc] = alpha*rC[2][0] + beta*C[0*ldc];
181-
C[16*ldc] = alpha*rC[2][1] + beta*C[16*ldc];
182-
C[32*ldc] = alpha*rC[2][2] + beta*C[32*ldc];
183-
C[48*ldc] = alpha*rC[2][3] + beta*C[48*ldc];
184-
C[64*ldc] = alpha*rC[2][4] + beta*C[64*ldc];
185-
C[80*ldc] = alpha*rC[2][5] + beta*C[80*ldc];
180+
C[0*ldc] = alpha*rC[2][0];
181+
C[16*ldc] = alpha*rC[2][1];
182+
C[32*ldc] = alpha*rC[2][2];
183+
C[48*ldc] = alpha*rC[2][3];
184+
C[64*ldc] = alpha*rC[2][4];
185+
C[80*ldc] = alpha*rC[2][5];
186186
C+=16;
187-
C[0*ldc] = alpha*rC[3][0] + beta*C[0*ldc];
188-
C[16*ldc] = alpha*rC[3][1] + beta*C[16*ldc];
189-
C[32*ldc] = alpha*rC[3][2] + beta*C[32*ldc];
190-
C[48*ldc] = alpha*rC[3][3] + beta*C[48*ldc];
191-
C[64*ldc] = alpha*rC[3][4] + beta*C[64*ldc];
192-
C[80*ldc] = alpha*rC[3][5] + beta*C[80*ldc];
187+
C[0*ldc] = alpha*rC[3][0];
188+
C[16*ldc] = alpha*rC[3][1];
189+
C[32*ldc] = alpha*rC[3][2];
190+
C[48*ldc] = alpha*rC[3][3];
191+
C[64*ldc] = alpha*rC[3][4];
192+
C[80*ldc] = alpha*rC[3][5];
193193
C+=16;
194-
C[0*ldc] = alpha*rC[4][0] + beta*C[0*ldc];
195-
C[16*ldc] = alpha*rC[4][1] + beta*C[16*ldc];
196-
C[32*ldc] = alpha*rC[4][2] + beta*C[32*ldc];
197-
C[48*ldc] = alpha*rC[4][3] + beta*C[48*ldc];
198-
C[64*ldc] = alpha*rC[4][4] + beta*C[64*ldc];
199-
C[80*ldc] = alpha*rC[4][5] + beta*C[80*ldc];
194+
C[0*ldc] = alpha*rC[4][0];
195+
C[16*ldc] = alpha*rC[4][1];
196+
C[32*ldc] = alpha*rC[4][2];
197+
C[48*ldc] = alpha*rC[4][3];
198+
C[64*ldc] = alpha*rC[4][4];
199+
C[80*ldc] = alpha*rC[4][5];
200200
C+=16;
201-
C[0*ldc] = alpha*rC[5][0] + beta*C[0*ldc];
202-
C[16*ldc] = alpha*rC[5][1] + beta*C[16*ldc];
203-
C[32*ldc] = alpha*rC[5][2] + beta*C[32*ldc];
204-
C[48*ldc] = alpha*rC[5][3] + beta*C[48*ldc];
205-
C[64*ldc] = alpha*rC[5][4] + beta*C[64*ldc];
206-
C[80*ldc] = alpha*rC[5][5] + beta*C[80*ldc];
201+
C[0*ldc] = alpha*rC[5][0];
202+
C[16*ldc] = alpha*rC[5][1];
203+
C[32*ldc] = alpha*rC[5][2];
204+
C[48*ldc] = alpha*rC[5][3];
205+
C[64*ldc] = alpha*rC[5][4];
206+
C[80*ldc] = alpha*rC[5][5];
207207

208208
}
209209
);

src/library/blas/AutoGemm/UserGemmKernelSources/sgemm_Col_NT_B1_MX032_NX032_KX16_BRANCH_src.cpp

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -57,41 +57,41 @@ __kernel void sgemm_Col_NT_B1_MX032_NX032_KX16_BRANCH (
5757
float rC[2][2] = { {(float)0} };
5858
float rA[1][2];
5959
float rB[1][2];
60-
61-
60+
61+
6262
A += offsetA;
6363
B += offsetB;
6464
C+=offsetC;
65-
65+
6666
__local float lA[528];//16*32+16
6767
__local float lB[528];
68-
68+
6969
uint gidx = get_group_id(0);
7070
uint gidy = get_group_id(1);
7171
uint idx = get_local_id(0);
7272
uint idy = get_local_id(1);
73-
73+
7474
int CurrentOffSetA = gidx*32+ idx;
7575
int CurrentOffSetB = gidy*32+ idx;
76-
76+
7777
A += gidx*32+ idx + idy*lda;
7878
B += gidy*32+ idx + idy*ldb;
79-
80-
79+
80+
8181
uint block_k = K >> 4;
82-
do
82+
do
8383
{
8484
__local float* plA = lA + idy*33+idx;
8585
__local float* plB = lB + idy*33+idx;
8686
barrier(CLK_LOCAL_MEM_FENCE);
87-
87+
8888
plB[0] = CurrentOffSetB>=N?0.0:B[0];
8989
plB[16] = CurrentOffSetB+16>=N?0.0:B[16];
90-
90+
9191
plA[0] = CurrentOffSetA>=M?0.0:A[0];
9292
plA[16] = CurrentOffSetA+16>=M?0.0:A[16];
9393

94-
94+
9595
barrier(CLK_LOCAL_MEM_FENCE);
9696
uint offA = idx;
9797
uint offB = idy;
@@ -126,23 +126,35 @@ __kernel void sgemm_Col_NT_B1_MX032_NX032_KX16_BRANCH (
126126
return;
127127

128128
C+=offset_x+offset_y*ldc;
129-
130-
int i = 0;
131-
do
132-
{
133-
C[0 ] = mad(alpha, rC[i][0], beta*C[0]);
134-
if(offset_y+16<N)
135-
C[16*ldc] = mad(alpha, rC[i][1], beta*C[16*ldc]);
136-
137-
C+=16;
138-
offset_x+=16;
139-
if(offset_x>=M )
140-
return;
141-
142-
143-
}
144-
while (++i < 2);
145129

130+
int i = 0;
131+
if (beta !=0 ) {
132+
do
133+
{
134+
C[0 ] = mad(alpha, rC[i][0], beta*C[0]);
135+
if(offset_y+16<N)
136+
C[16*ldc] = mad(alpha, rC[i][1], beta*C[16*ldc]);
137+
138+
C+=16;
139+
offset_x+=16;
140+
if(offset_x>=M )
141+
return;
142+
}
143+
while (++i < 2);
144+
} else {
145+
do
146+
{
147+
C[0 ] = alpha * rC[i][0];
148+
if(offset_y+16<N)
149+
C[16*ldc] = alpha * rC[i][1];
150+
151+
C+=16;
152+
offset_x+=16;
153+
if(offset_x>=M )
154+
return;
155+
}
156+
while (++i < 2);
157+
}
146158
}
147159
);
148160
#endif

0 commit comments

Comments
 (0)