Skip to content

Commit de63675

Browse files
authored
Add early returns and fix sign errors in workspace calculations
1 parent d64cc2b commit de63675

File tree

8 files changed

+65
-27
lines changed

8 files changed

+65
-27
lines changed

relapack/src/cgbtrf.c

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ void RELAPACK_cgbtrf(
3636
return;
3737
}
3838

39+
if (*m == 0 || *n == 0) return;
3940
// Constant
4041
const float ZERO[] = { 0., 0. };
4142

@@ -56,10 +57,10 @@ void RELAPACK_cgbtrf(
5657

5758
// Allocate work space
5859
const blasint n1 = CREC_SPLIT(*n);
59-
const blasint mWorkl = (kv > n1) ? MAX(1, *m - *kl) : kv;
60-
const blasint nWorkl = (kv > n1) ? n1 : kv;
61-
const blasint mWorku = (*kl > n1) ? n1 : *kl;
62-
const blasint nWorku = (*kl > n1) ? MAX(0, *n - *kl) : *kl;
60+
const blasint mWorkl = abs ( (kv > n1) ? MAX(1, *m - *kl) : kv);
61+
const blasint nWorkl = abs ( (kv > n1) ? n1 : kv);
62+
const blasint mWorku = abs ((*kl > n1) ? n1 : *kl);
63+
const blasint nWorku = abs ((*kl > n1) ? MAX(0, *n - *kl) : *kl);
6364
float *Workl = malloc(mWorkl * nWorkl * 2 * sizeof(float));
6465
float *Worku = malloc(mWorku * nWorku * 2 * sizeof(float));
6566
LAPACK(claset)("L", &mWorkl, &nWorkl, ZERO, ZERO, Workl, &mWorkl);
@@ -82,7 +83,7 @@ static void RELAPACK_cgbtrf_rec(
8283
blasint *info
8384
) {
8485

85-
if (*n <= MAX(CROSSOVER_CGBTRF, 1)) {
86+
if (*n <= MAX(CROSSOVER_CGBTRF, 1)|| *n > *kl || *ldAb == 1) {
8687
// Unblocked
8788
LAPACK(cgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
8889
return;

relapack/src/cpbtrf.c

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,8 @@ void RELAPACK_cpbtrf(
3535
return;
3636
}
3737

38+
if (*n == 0) return;
39+
3840
// Clean char * arguments
3941
const char cleanuplo = lower ? 'L' : 'U';
4042

@@ -43,8 +45,8 @@ void RELAPACK_cpbtrf(
4345

4446
// Allocate work space
4547
const blasint n1 = CREC_SPLIT(*n);
46-
const blasint mWork = (*kd > n1) ? (lower ? *n - *kd : n1) : *kd;
47-
const blasint nWork = (*kd > n1) ? (lower ? n1 : *n - *kd) : *kd;
48+
const blasint mWork = abs((*kd > n1) ? (lower ? *n - *kd : n1) : *kd);
49+
const blasint nWork = abs((*kd > n1) ? (lower ? n1 : *n - *kd) : *kd);
4850
float *Work = malloc(mWork * nWork * 2 * sizeof(float));
4951
LAPACK(claset)(uplo, &mWork, &nWork, ZERO, ZERO, Work, &mWork);
5052

@@ -64,7 +66,7 @@ static void RELAPACK_cpbtrf_rec(
6466
blasint *info
6567
){
6668

67-
if (*n <= MAX(CROSSOVER_CPBTRF, 1)) {
69+
if (*n <= MAX(CROSSOVER_CPBTRF, 1) || *ldAb==1) {
6870
// Unblocked
6971
LAPACK(cpbtf2)(uplo, n, kd, Ab, ldAb, info);
7072
return;
@@ -148,7 +150,7 @@ static void RELAPACK_cpbtrf_rec(
148150
}
149151

150152
// recursion(A_BR)
151-
if (*kd > n1)
153+
if (*kd > n1 && ldA != 0)
152154
RELAPACK_cpotrf(uplo, &n2, A_BR, ldA, info);
153155
else
154156
RELAPACK_cpbtrf_rec(uplo, &n2, kd, Ab_BR, ldAb, Work, ldWork, info);

relapack/src/dgbtrf.c

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,8 @@ void RELAPACK_dgbtrf(
3636
return;
3737
}
3838

39+
if (*m == 0 || *n == 0) return;
40+
3941
// Constant
4042
const double ZERO[] = { 0. };
4143

@@ -83,7 +85,7 @@ static void RELAPACK_dgbtrf_rec(
8385
blasint *info
8486
) {
8587

86-
if (*n <= MAX(CROSSOVER_DGBTRF, 1)) {
88+
if (*n <= MAX(CROSSOVER_DGBTRF, 1) || *n > *kl || *ldAb == 1) {
8789
// Unblocked
8890
LAPACK(dgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
8991
return;
@@ -195,6 +197,7 @@ static void RELAPACK_dgbtrf_rec(
195197
// Worku = A_TRr
196198
LAPACK(dlacpy)("L", &m1, &n22, A_TRr, ldA, Worku, ldWorku);
197199
// Worku = A_TL \ Worku
200+
if (ldWorku <= 0) return;
198201
BLAS(dtrsm)("L", "L", "N", "U", &m1, &n22, ONE, A_TL, ldA, Worku, ldWorku);
199202
// A_TRr = Worku
200203
LAPACK(dlacpy)("L", &m1, &n22, Worku, ldWorku, A_TRr, ldA);

relapack/src/dpbtrf.c

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,8 @@ void RELAPACK_dpbtrf(
3535
return;
3636
}
3737

38+
if (*n == 0) return;
39+
3840
// Clean char * arguments
3941
const char cleanuplo = lower ? 'L' : 'U';
4042

@@ -43,8 +45,8 @@ void RELAPACK_dpbtrf(
4345

4446
// Allocate work space
4547
const blasint n1 = DREC_SPLIT(*n);
46-
const blasint mWork = (*kd > n1) ? (lower ? *n - *kd : n1) : *kd;
47-
const blasint nWork = (*kd > n1) ? (lower ? n1 : *n - *kd) : *kd;
48+
const blasint mWork = abs((*kd > n1) ? (lower ? *n - *kd : n1) : *kd);
49+
const blasint nWork = abs((*kd > n1) ? (lower ? n1 : *n - *kd) : *kd);
4850
double *Work = malloc(mWork * nWork * sizeof(double));
4951
LAPACK(dlaset)(uplo, &mWork, &nWork, ZERO, ZERO, Work, &mWork);
5052

@@ -64,7 +66,7 @@ static void RELAPACK_dpbtrf_rec(
6466
blasint *info
6567
){
6668

67-
if (*n <= MAX(CROSSOVER_DPBTRF, 1)) {
69+
if (*n <= MAX(CROSSOVER_DPBTRF, 1) || *ldAb == 1) {
6870
// Unblocked
6971
LAPACK(dpbtf2)(uplo, n, kd, Ab, ldAb, info);
7072
return;
@@ -148,7 +150,7 @@ static void RELAPACK_dpbtrf_rec(
148150
}
149151

150152
// recursion(A_BR)
151-
if (*kd > n1)
153+
if (*kd > n1 && ldA != 0)
152154
RELAPACK_dpotrf(uplo, &n2, A_BR, ldA, info);
153155
else
154156
RELAPACK_dpbtrf_rec(uplo, &n2, kd, Ab_BR, ldAb, Work, ldWork, info);

relapack/src/sgbtrf.c

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,13 @@ void RELAPACK_sgbtrf(
3535
return;
3636
}
3737

38+
if (*m == 0 || *n == 0) return;
39+
40+
if (*ldAb == 1) {
41+
LAPACK(sgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
42+
return;
43+
}
44+
3845
// Constant
3946
const float ZERO[] = { 0. };
4047

@@ -82,8 +89,9 @@ static void RELAPACK_sgbtrf_rec(
8289
blasint *info
8390
) {
8491

92+
if (*m == 0 || *n == 0) return;
8593

86-
if (*n <= MAX(CROSSOVER_SGBTRF, 1)) {
94+
if ( *n <= MAX(CROSSOVER_SGBTRF, 1) || *n > *kl || *ldAb == 1) {
8795
// Unblocked
8896
LAPACK(sgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
8997
return;
@@ -160,7 +168,7 @@ static void RELAPACK_sgbtrf_rec(
160168

161169
// recursion(Ab_L, ipiv_T)
162170
RELAPACK_sgbtrf_rec(m, &n1, kl, ku, Ab_L, ldAb, ipiv_T, Workl, ldWorkl, Worku, ldWorku, info);
163-
171+
if (*info) return;
164172
// Workl = A_BLb
165173
LAPACK(slacpy)("U", &m22, &n1, A_BLb, ldA, Workl, ldWorkl);
166174

@@ -222,8 +230,8 @@ static void RELAPACK_sgbtrf_rec(
222230

223231
// recursion(Ab_BR, ipiv_B)
224232
//cause of infinite recursion here ?
225-
// RELAPACK_sgbtrf_rec(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, Workl, ldWorkl, Worku, ldWorku, info);
226-
LAPACK(sgbtf2)(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, info);
233+
RELAPACK_sgbtrf_rec(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, Workl, ldWorkl, Worku, ldWorku, info);
234+
// LAPACK(sgbtf2)(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, info);
227235
if (*info)
228236
*info += n1;
229237
// shift pivots

relapack/src/spbtrf.c

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,9 @@ void RELAPACK_spbtrf(
3535
return;
3636
}
3737

38+
39+
if (*n == 0) return;
40+
3841
// Clean char * arguments
3942
const char cleanuplo = lower ? 'L' : 'U';
4043

@@ -43,8 +46,8 @@ void RELAPACK_spbtrf(
4346

4447
// Allocate work space
4548
const blasint n1 = SREC_SPLIT(*n);
46-
const blasint mWork = (*kd > n1) ? (lower ? *n - *kd : n1) : *kd;
47-
const blasint nWork = (*kd > n1) ? (lower ? n1 : *n - *kd) : *kd;
49+
const blasint mWork = abs( (*kd > n1) ? (lower ? *n - *kd : n1) : *kd);
50+
const blasint nWork = abs((*kd > n1) ? (lower ? n1 : *n - *kd) : *kd);
4851
float *Work = malloc(mWork * nWork * sizeof(float));
4952
LAPACK(slaset)(uplo, &mWork, &nWork, ZERO, ZERO, Work, &mWork);
5053

@@ -64,7 +67,9 @@ static void RELAPACK_spbtrf_rec(
6467
blasint *info
6568
){
6669

67-
if (*n <= MAX(CROSSOVER_SPBTRF, 1)) {
70+
if (*n == 0 ) return;
71+
72+
if ( *n <= MAX(CROSSOVER_SPBTRF, 1) || *ldAb == 1) {
6873
// Unblocked
6974
LAPACK(spbtf2)(uplo, n, kd, Ab, ldAb, info);
7075
return;
@@ -148,7 +153,7 @@ static void RELAPACK_spbtrf_rec(
148153
}
149154

150155
// recursion(A_BR)
151-
if (*kd > n1)
156+
if (*kd > n1 && ldA != 0)
152157
RELAPACK_spotrf(uplo, &n2, A_BR, ldA, info);
153158
else
154159
RELAPACK_spbtrf_rec(uplo, &n2, kd, Ab_BR, ldAb, Work, ldWork, info);

relapack/src/zgbtrf.c

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,8 @@ void RELAPACK_zgbtrf(
3636
return;
3737
}
3838

39+
if (*m == 0 || *n == 0) return;
40+
3941
// Constant
4042
const double ZERO[] = { 0., 0. };
4143

@@ -82,7 +84,7 @@ static void RELAPACK_zgbtrf_rec(
8284
blasint *info
8385
) {
8486

85-
if (*n <= MAX(CROSSOVER_ZGBTRF, 1)) {
87+
if (*n <= MAX(CROSSOVER_ZGBTRF, 1) || *n > *kl || *ldAb == 1) {
8688
// Unblocked
8789
LAPACK(zgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
8890
return;
@@ -92,6 +94,7 @@ static void RELAPACK_zgbtrf_rec(
9294
const double ONE[] = { 1., 0. };
9395
const double MONE[] = { -1., 0. };
9496
const blasint iONE[] = { 1 };
97+
const blasint min11 = -11;
9598

9699
// Loop iterators
97100
blasint i, j;
@@ -158,6 +161,7 @@ static void RELAPACK_zgbtrf_rec(
158161

159162
// recursion(Ab_L, ipiv_T)
160163
RELAPACK_zgbtrf_rec(m, &n1, kl, ku, Ab_L, ldAb, ipiv_T, Workl, ldWorkl, Worku, ldWorku, info);
164+
if (*info) return;
161165

162166
// Workl = A_BLb
163167
LAPACK(zlacpy)("U", &m22, &n1, A_BLb, ldA, Workl, ldWorkl);
@@ -193,11 +197,21 @@ static void RELAPACK_zgbtrf_rec(
193197
}
194198

195199
// A_TRl = A_TL \ A_TRl
200+
if (*ldA < MAX(1,m1)) {
201+
LAPACK(xerbla)("ZGBTRF", &min11, strlen("ZGBTRF"));
202+
return;
203+
} else {
196204
BLAS(ztrsm)("L", "L", "N", "U", &m1, &n21, ONE, A_TL, ldA, A_TRl, ldA);
205+
}
197206
// Worku = A_TRr
198207
LAPACK(zlacpy)("L", &m1, &n22, A_TRr, ldA, Worku, ldWorku);
199208
// Worku = A_TL \ Worku
209+
if (*ldWorku < MAX(1,m1)) {
210+
LAPACK(xerbla)("ZGBTRF", &min11, strlen("ZGBTRF"));
211+
return;
212+
} else {
200213
BLAS(ztrsm)("L", "L", "N", "U", &m1, &n22, ONE, A_TL, ldA, Worku, ldWorku);
214+
}
201215
// A_TRr = Worku
202216
LAPACK(zlacpy)("L", &m1, &n22, Worku, ldWorku, A_TRr, ldA);
203217
// A_BRtl = A_BRtl - A_BLt * A_TRl

relapack/src/zpbtrf.c

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,8 @@ void RELAPACK_zpbtrf(
3535
return;
3636
}
3737

38+
if (*n == 0) return;
39+
3840
// Clean char * arguments
3941
const char cleanuplo = lower ? 'L' : 'U';
4042

@@ -43,9 +45,10 @@ void RELAPACK_zpbtrf(
4345

4446
// Allocate work space
4547
const blasint n1 = ZREC_SPLIT(*n);
46-
const blasint mWork = (*kd > n1) ? (lower ? *n - *kd : n1) : *kd;
47-
const blasint nWork = (*kd > n1) ? (lower ? n1 : *n - *kd) : *kd;
48+
const blasint mWork = abs((*kd > n1) ? (lower ? *n - *kd : n1) : *kd);
49+
const blasint nWork = abs((*kd > n1) ? (lower ? n1 : *n - *kd) : *kd);
4850
double *Work = malloc(mWork * nWork * 2 * sizeof(double));
51+
4952
LAPACK(zlaset)(uplo, &mWork, &nWork, ZERO, ZERO, Work, &mWork);
5053

5154
// Recursive kernel
@@ -64,7 +67,7 @@ static void RELAPACK_zpbtrf_rec(
6467
blasint *info
6568
){
6669

67-
if (*n <= MAX(CROSSOVER_ZPBTRF, 1)) {
70+
if (*n <= MAX(CROSSOVER_ZPBTRF, 1) || *ldAb == 1) {
6871
// Unblocked
6972
LAPACK(zpbtf2)(uplo, n, kd, Ab, ldAb, info);
7073
return;
@@ -148,7 +151,7 @@ static void RELAPACK_zpbtrf_rec(
148151
}
149152

150153
// recursion(A_BR)
151-
if (*kd > n1)
154+
if (*kd > n1 && ldA != 0)
152155
RELAPACK_zpotrf(uplo, &n2, A_BR, ldA, info);
153156
else
154157
RELAPACK_zpbtrf_rec(uplo, &n2, kd, Ab_BR, ldAb, Work, ldWork, info);

0 commit comments

Comments
 (0)