Skip to content

Commit 4e529f2

Browse files
shakedregev and pelesh authored
Properly address CSR and CSC discrepancies across all solvers.
Fixed discrepancies with minimal overhead. --------- Co-authored-by: shakedregev <shakedregev@users.noreply.github.com> Co-authored-by: pelesh <peless@ornl.gov>
1 parent 885edfa commit 4e529f2

14 files changed

+468753
-88
lines changed

resolve/LinSolverDirectCuSolverRf.cpp

Lines changed: 114 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#include "LinSolverDirectCuSolverRf.hpp"
22

3+
#include <algorithm>
34
#include <cassert>
5+
#include <cstring> // includes memcpy
6+
#include <vector>
47

58
#include <resolve/matrix/Csc.hpp>
69
#include <resolve/matrix/Csr.hpp>
@@ -44,6 +47,110 @@ namespace ReSolve
4447
mem_.deleteOnDevice(d_T_);
4548
}
4649

50+
/**
51+
* @brief Setup the cuSolverRf factorization with factors already in CSR
52+
*
53+
* Sets up the cuSolverRf factorization for the given matrix A and its
54+
* L and U factors. The permutation vectors P and Q are also set up.
55+
*
56+
* @param[in] A - pointer to the matrix A
57+
* @param[in] L - pointer to the lower triangular factor L in CSR
58+
* @param[in] U - pointer to the upper triangular factor U in CSR
59+
* @param[in] P - pointer to the permutation vector P
60+
* @param[in] Q - pointer to the permutation vector Q
61+
* @param[in] rhs - pointer to the right-hand side vector (optional)
62+
*
63+
* @pre The matrix A is in CSR format.
64+
*/
65+
66+
int LinSolverDirectCuSolverRf::setupCsr(matrix::Sparse* A,
67+
matrix::Sparse* L,
68+
matrix::Sparse* U,
69+
index_type* P,
70+
index_type* Q,
71+
vector_type* /* rhs */)
72+
{
73+
assert(A->getSparseFormat() == matrix::Sparse::COMPRESSED_SPARSE_ROW && "Matrix A has to be in CSR format for cusolverRf input.\n");
74+
assert(L->getSparseFormat() == U->getSparseFormat() && "Matrices L and U have to be in the same format for cusolverRf input.\n");
75+
assert(L->getSparseFormat() == matrix::Sparse::COMPRESSED_SPARSE_ROW && "Matrices L and U have to be in CSR format for cusolverRf input.\n");
76+
int error_sum = 0;
77+
this->A_ = A;
78+
index_type n = A_->getNumRows();
79+
80+
// Remember - P and Q are generally CPU variables!
81+
// Factorization data is stored in the handle.
82+
// If function is called again, destroy the old handle to get rid of old data.
83+
if (setup_completed_)
84+
{
85+
cusolverRfDestroy(handle_cusolverrf_);
86+
cusolverRfCreate(&handle_cusolverrf_);
87+
}
88+
89+
if (d_P_ == nullptr)
90+
{
91+
mem_.allocateArrayOnDevice(&d_P_, n);
92+
}
93+
94+
if (d_Q_ == nullptr)
95+
{
96+
mem_.allocateArrayOnDevice(&d_Q_, n);
97+
}
98+
99+
if (d_T_ != nullptr)
100+
{
101+
mem_.deleteOnDevice(d_T_);
102+
}
103+
104+
mem_.allocateArrayOnDevice(&d_T_, n);
105+
106+
mem_.copyArrayHostToDevice(d_P_, P, n);
107+
mem_.copyArrayHostToDevice(d_Q_, Q, n);
108+
109+
status_cusolverrf_ = cusolverRfSetResetValuesFastMode(handle_cusolverrf_, CUSOLVERRF_RESET_VALUES_FAST_MODE_ON);
110+
error_sum += status_cusolverrf_;
111+
// sort L and U columns
112+
for (index_type i = 0; i < n; ++i)
113+
{
114+
std::sort(L->getColData(memory::HOST) + L->getRowData(memory::HOST)[i],
115+
L->getColData(memory::HOST) + L->getRowData(memory::HOST)[i + 1]);
116+
std::sort(U->getColData(memory::HOST) + U->getRowData(memory::HOST)[i],
117+
U->getColData(memory::HOST) + U->getRowData(memory::HOST)[i + 1]);
118+
}
119+
L->setUpdated(memory::HOST);
120+
U->setUpdated(memory::HOST);
121+
L->syncData(memory::DEVICE);
122+
U->syncData(memory::DEVICE);
123+
status_cusolverrf_ = cusolverRfSetupDevice(n,
124+
A_->getNnz(),
125+
A_->getRowData(memory::DEVICE),
126+
A_->getColData(memory::DEVICE),
127+
A_->getValues(memory::DEVICE),
128+
L->getNnz(),
129+
L->getRowData(memory::DEVICE),
130+
L->getColData(memory::DEVICE),
131+
L->getValues(memory::DEVICE),
132+
U->getNnz(),
133+
U->getRowData(memory::DEVICE),
134+
U->getColData(memory::DEVICE),
135+
U->getValues(memory::DEVICE),
136+
d_P_,
137+
d_Q_,
138+
handle_cusolverrf_);
139+
error_sum += status_cusolverrf_;
140+
mem_.deviceSynchronize();
141+
status_cusolverrf_ = cusolverRfAnalyze(handle_cusolverrf_);
142+
error_sum += status_cusolverrf_;
143+
const cusolverRfFactorization_t fact_alg =
144+
CUSOLVERRF_FACTORIZATION_ALG0; // 0 - default, 1 or 2
145+
const cusolverRfTriangularSolve_t solve_alg =
146+
CUSOLVERRF_TRIANGULAR_SOLVE_ALG1; // 1- default, 2 or 3
147+
148+
this->setAlgorithms(fact_alg, solve_alg);
149+
150+
setup_completed_ = true;
151+
return error_sum;
152+
}
153+
47154
/**
48155
* @brief Setup the cuSolverRf factorization
49156
*
@@ -90,25 +197,26 @@ namespace ReSolve
90197
switch (L->getSparseFormat())
91198
{
92199
case matrix::Sparse::COMPRESSED_SPARSE_COLUMN:
93-
// std::cout << "converting L and U factors from CSC to CSR format ...\n";
94200
L_csc = static_cast<matrix::Csc*>(L);
95201
U_csc = static_cast<matrix::Csc*>(U);
96202
L_csr = new matrix::Csr(L_csc->getNumRows(), L_csc->getNumColumns(), L_csc->getNnz());
97203
U_csr = new matrix::Csr(U_csc->getNumRows(), U_csc->getNumColumns(), U_csc->getNnz());
98204
csc2csr(L_csc, L_csr);
99205
csc2csr(U_csc, U_csr);
100-
L_csr->syncData(memory::DEVICE);
101-
U_csr->syncData(memory::DEVICE);
102206
break;
103207
case matrix::Sparse::COMPRESSED_SPARSE_ROW:
104-
L_csr = dynamic_cast<matrix::Csr*>(L);
105-
U_csr = dynamic_cast<matrix::Csr*>(U);
208+
L_csr = static_cast<matrix::Csr*>(L);
209+
U_csr = static_cast<matrix::Csr*>(U);
210+
L_csr->setUpdated(memory::HOST);
211+
U_csr->setUpdated(memory::HOST);
106212
break;
107213
default:
108214
out::error() << "Matrix type for L and U factors not recognized!\n";
109215
out::error() << "Refactorization not completed.\n";
110216
return 1;
111217
}
218+
L_csr->syncData(memory::DEVICE);
219+
U_csr->syncData(memory::DEVICE);
112220

113221
if (d_P_ == nullptr)
114222
{
@@ -182,9 +290,6 @@ namespace ReSolve
182290
default:
183291
break;
184292
}
185-
// delete L_csr;
186-
// delete U_csr;
187-
188293
return error_sum;
189294
}
190295

@@ -296,6 +401,7 @@ namespace ReSolve
296401
A_->getNumRows(),
297402
x->getData(memory::DEVICE),
298403
A_->getNumRows());
404+
x->setDataUpdated(memory::DEVICE);
299405
return status_cusolverrf_;
300406
}
301407

resolve/LinSolverDirectCuSolverRf.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ namespace ReSolve
3939
index_type* Q,
4040
vector_type* rhs = nullptr) override;
4141

42+
int setupCsr(matrix::Sparse* A,
43+
matrix::Sparse* L,
44+
matrix::Sparse* U,
45+
index_type* P,
46+
index_type* Q,
47+
vector_type* rhs = nullptr);
48+
4249
int refactorize() override;
4350
int solve(vector_type* rhs, vector_type* x) override;
4451
int solve(vector_type* rhs) override; // rhs overwritten by solution

resolve/LinSolverDirectKLU.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,9 @@ namespace ReSolve
390390
/**
391391
* @brief Get the permutation vector P.
392392
*
393+
* Due to KLU's internal CSC storage, the P vector is obtained from Symbolic_->Q,
394+
* to keep things consistent with our CSR storage convention.
395+
*
393396
* @return P permutation vector
394397
*/
395398
index_type* LinSolverDirectKLU::getPOrdering()
@@ -398,7 +401,8 @@ namespace ReSolve
398401
{
399402
P_ = new index_type[A_->getNumRows()];
400403
size_t nrows = static_cast<size_t>(A_->getNumRows());
401-
std::memcpy(P_, Numeric_->Pnum, nrows * sizeof(index_type));
404+
std::memcpy(P_, Symbolic_->Q, nrows * sizeof(index_type)); // KLU's CSC Symbolic_->Q is the CSR P vector.
405+
// Only a symbolic factorization is needed to get Q, because there is only row pivoting for the numeric factorization.
402406
return P_;
403407
}
404408
else
@@ -410,6 +414,9 @@ namespace ReSolve
410414
/**
411415
* @brief Get the permutation vector Q.
412416
*
417+
* Due to KLU's internal CSC storage, the Q vector is obtained from Numeric_->Pnum,
418+
* to keep things consistent with our CSR storage convention.
419+
*
413420
* @return Q permutation vector
414421
*/
415422
index_type* LinSolverDirectKLU::getQOrdering()
@@ -418,7 +425,8 @@ namespace ReSolve
418425
{
419426
Q_ = new index_type[A_->getNumRows()];
420427
size_t nrows = static_cast<size_t>(A_->getNumRows());
421-
std::memcpy(Q_, Symbolic_->Q, nrows * sizeof(index_type));
428+
std::memcpy(Q_, Numeric_->Pnum, nrows * sizeof(index_type)); // KLU's CSC Numeric_->Pnum is the CSR Q vector.
429+
// A numeric factorization is needed to get Pnum, because there is row pivoting for the numeric factorization.
422430
return Q_;
423431
}
424432
else

0 commit comments

Comments
 (0)