Skip to content

Commit 72b5b73

Browse files
authored
Merge pull request #2850 from xiaojiayuan111/develop
Fix a bug in trmm
2 parents: dfaafd3 + 06cf73a — commit 72b5b73

File tree

1 file changed

+42 -6 lines changed

1 file changed

+42 -6 lines changed

driver/level3/trmm_L.c

Lines changed: 42 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -122,6 +122,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
122122
if (min_l > GEMM_Q) min_l = GEMM_Q;
123123
min_i = min_l;
124124
if (min_i > GEMM_P) min_i = GEMM_P;
125+
if( min_i > GEMM_UNROLL_M){
126+
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
127+
}
125128

126129
START_RPCC();
127130

@@ -161,9 +164,12 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
161164
}
162165

163166

164-
for(is = min_i; is < min_l; is += GEMM_P){
167+
for(is = min_i; is < min_l; is += min_i){
165168
min_i = min_l - is;
166169
if (min_i > GEMM_P) min_i = GEMM_P;
170+
if( min_i > GEMM_UNROLL_M){
171+
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
172+
}
167173

168174
START_RPCC();
169175

@@ -192,6 +198,10 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
192198
if (min_l > GEMM_Q) min_l = GEMM_Q;
193199
min_i = ls;
194200
if (min_i > GEMM_P) min_i = GEMM_P;
201+
if( min_i > GEMM_UNROLL_M){
202+
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
203+
}
204+
195205

196206
START_RPCC();
197207

@@ -231,9 +241,12 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
231241
STOP_RPCC(gemmcost);
232242
}
233243

234-
for(is = min_i; is < ls; is += GEMM_P){
244+
for(is = min_i; is < ls; is += min_i){
235245
min_i = ls - is;
236246
if (min_i > GEMM_P) min_i = GEMM_P;
247+
if( min_i > GEMM_UNROLL_M){
248+
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
249+
}
237250

238251
START_RPCC();
239252

@@ -256,9 +269,12 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
256269
STOP_RPCC(gemmcost);
257270
}
258271

259-
for(is = ls; is < ls + min_l; is += GEMM_P){
272+
for(is = ls; is < ls + min_l; is += min_i){
260273
min_i = ls + min_l - is;
261274
if (min_i > GEMM_P) min_i = GEMM_P;
275+
if( min_i > GEMM_UNROLL_M){
276+
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
277+
}
262278

263279
START_RPCC();
264280

@@ -287,6 +303,10 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
287303
if (min_l > GEMM_Q) min_l = GEMM_Q;
288304
min_i = min_l;
289305
if (min_i > GEMM_P) min_i = GEMM_P;
306+
if (min_i > GEMM_UNROLL_M){
307+
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
308+
}
309+
290310

291311
START_RPCC();
292312

@@ -327,9 +347,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
327347
STOP_RPCC(trmmcost);
328348
}
329349

330-
for(is = m - min_l + min_i; is < m; is += GEMM_P){
350+
for(is = m - min_l + min_i; is < m; is += min_i){
331351
min_i = m - is;
332352
if (min_i > GEMM_P) min_i = GEMM_P;
353+
if (min_i > GEMM_UNROLL_M){
354+
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
355+
}
356+
357+
333358

334359
START_RPCC();
335360

@@ -357,6 +382,10 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
357382
if (min_l > GEMM_Q) min_l = GEMM_Q;
358383
min_i = min_l;
359384
if (min_i > GEMM_P) min_i = GEMM_P;
385+
if (min_i > GEMM_UNROLL_M){
386+
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
387+
}
388+
360389

361390
START_RPCC();
362391

@@ -397,9 +426,13 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
397426
STOP_RPCC(trmmcost);
398427
}
399428

400-
for(is = ls - min_l + min_i; is < ls; is += GEMM_P){
429+
for(is = ls - min_l + min_i; is < ls; is += min_i){
401430
min_i = ls - is;
402431
if (min_i > GEMM_P) min_i = GEMM_P;
432+
if (min_i > GEMM_UNROLL_M){
433+
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
434+
}
435+
403436

404437
START_RPCC();
405438

@@ -423,9 +456,12 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
423456
}
424457

425458

426-
for(is = ls; is < m; is += GEMM_P){
459+
for(is = ls; is < m; is += min_i){
427460
min_i = m - is;
428461
if (min_i > GEMM_P) min_i = GEMM_P;
462+
if (min_i > GEMM_UNROLL_M){
463+
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
464+
}
429465

430466
START_RPCC();
431467

0 commit comments

Comments
 (0)