|
1 | 1 | #include "LinSolverDirectCuSolverRf.hpp" |
2 | 2 |
|
| 3 | +#include <algorithm> |
3 | 4 | #include <cassert> |
| 5 | +#include <cstring> // includes memcpy |
| 6 | +#include <vector> |
4 | 7 |
|
5 | 8 | #include <resolve/matrix/Csc.hpp> |
6 | 9 | #include <resolve/matrix/Csr.hpp> |
@@ -44,6 +47,110 @@ namespace ReSolve |
44 | 47 | mem_.deleteOnDevice(d_T_); |
45 | 48 | } |
46 | 49 |
|
| 50 | + /** |
| 51 | + * @brief Setup the cuSolverRf factorization with factors already in CSR |
| 52 | + * |
| 53 | + * Sets up the cuSolverRf factorization for the given matrix A and its |
| 54 | + * L and U factors. The permutation vectors P and Q are also set up. |
| 55 | + * |
| 56 | + * @param[in] A - pointer to the matrix A |
| 57 | + * @param[in] L - pointer to the lower triangular factor L in CSR |
| 58 | + * @param[in] U - pointer to the upper triangular factor U in CSR |
| 59 | + * @param[in] P - pointer to the permutation vector P |
| 60 | + * @param[in] Q - pointer to the permutation vector Q |
| 61 | + * @param[in] rhs - pointer to the right-hand side vector (optional) |
| 62 | + * |
| 63 | + * @pre The matrix A is in CSR format. |
| 64 | + */ |
| 65 | + |
| 66 | + int LinSolverDirectCuSolverRf::setupCsr(matrix::Sparse* A, |
| 67 | + matrix::Sparse* L, |
| 68 | + matrix::Sparse* U, |
| 69 | + index_type* P, |
| 70 | + index_type* Q, |
| 71 | + vector_type* /* rhs */) |
| 72 | + { |
| 73 | + assert(A->getSparseFormat() == matrix::Sparse::COMPRESSED_SPARSE_ROW && "Matrix A has to be in CSR format for cusolverRf input.\n"); |
| 74 | + assert(L->getSparseFormat() == U->getSparseFormat() && "Matrices L and U have to be in the same format for cusolverRf input.\n"); |
| 75 | + assert(L->getSparseFormat() == matrix::Sparse::COMPRESSED_SPARSE_ROW && "Matrices L and U have to be in CSR format for cusolverRf input.\n"); |
| 76 | + int error_sum = 0; |
| 77 | + this->A_ = A; |
| 78 | + index_type n = A_->getNumRows(); |
| 79 | + |
| 80 | + // Remember - P and Q are generally CPU variables! |
| 81 | + // Factorization data is stored in the handle. |
| 82 | + // If function is called again, destroy the old handle to get rid of old data. |
| 83 | + if (setup_completed_) |
| 84 | + { |
| 85 | + cusolverRfDestroy(handle_cusolverrf_); |
| 86 | + cusolverRfCreate(&handle_cusolverrf_); |
| 87 | + } |
| 88 | + |
| 89 | + if (d_P_ == nullptr) |
| 90 | + { |
| 91 | + mem_.allocateArrayOnDevice(&d_P_, n); |
| 92 | + } |
| 93 | + |
| 94 | + if (d_Q_ == nullptr) |
| 95 | + { |
| 96 | + mem_.allocateArrayOnDevice(&d_Q_, n); |
| 97 | + } |
| 98 | + |
| 99 | + if (d_T_ != nullptr) |
| 100 | + { |
| 101 | + mem_.deleteOnDevice(d_T_); |
| 102 | + } |
| 103 | + |
| 104 | + mem_.allocateArrayOnDevice(&d_T_, n); |
| 105 | + |
| 106 | + mem_.copyArrayHostToDevice(d_P_, P, n); |
| 107 | + mem_.copyArrayHostToDevice(d_Q_, Q, n); |
| 108 | + |
| 109 | + status_cusolverrf_ = cusolverRfSetResetValuesFastMode(handle_cusolverrf_, CUSOLVERRF_RESET_VALUES_FAST_MODE_ON); |
| 110 | + error_sum += status_cusolverrf_; |
| 111 | + // sort L and U columns |
| 112 | + for (index_type i = 0; i < n; ++i) |
| 113 | + { |
| 114 | + std::sort(L->getColData(memory::HOST) + L->getRowData(memory::HOST)[i], |
| 115 | + L->getColData(memory::HOST) + L->getRowData(memory::HOST)[i + 1]); |
| 116 | + std::sort(U->getColData(memory::HOST) + U->getRowData(memory::HOST)[i], |
| 117 | + U->getColData(memory::HOST) + U->getRowData(memory::HOST)[i + 1]); |
| 118 | + } |
| 119 | + L->setUpdated(memory::HOST); |
| 120 | + U->setUpdated(memory::HOST); |
| 121 | + L->syncData(memory::DEVICE); |
| 122 | + U->syncData(memory::DEVICE); |
| 123 | + status_cusolverrf_ = cusolverRfSetupDevice(n, |
| 124 | + A_->getNnz(), |
| 125 | + A_->getRowData(memory::DEVICE), |
| 126 | + A_->getColData(memory::DEVICE), |
| 127 | + A_->getValues(memory::DEVICE), |
| 128 | + L->getNnz(), |
| 129 | + L->getRowData(memory::DEVICE), |
| 130 | + L->getColData(memory::DEVICE), |
| 131 | + L->getValues(memory::DEVICE), |
| 132 | + U->getNnz(), |
| 133 | + U->getRowData(memory::DEVICE), |
| 134 | + U->getColData(memory::DEVICE), |
| 135 | + U->getValues(memory::DEVICE), |
| 136 | + d_P_, |
| 137 | + d_Q_, |
| 138 | + handle_cusolverrf_); |
| 139 | + error_sum += status_cusolverrf_; |
| 140 | + mem_.deviceSynchronize(); |
| 141 | + status_cusolverrf_ = cusolverRfAnalyze(handle_cusolverrf_); |
| 142 | + error_sum += status_cusolverrf_; |
| 143 | + const cusolverRfFactorization_t fact_alg = |
| 144 | + CUSOLVERRF_FACTORIZATION_ALG0; // 0 - default, 1 or 2 |
| 145 | + const cusolverRfTriangularSolve_t solve_alg = |
| 146 | + CUSOLVERRF_TRIANGULAR_SOLVE_ALG1; // 1- default, 2 or 3 |
| 147 | + |
| 148 | + this->setAlgorithms(fact_alg, solve_alg); |
| 149 | + |
| 150 | + setup_completed_ = true; |
| 151 | + return error_sum; |
| 152 | + } |
| 153 | + |
47 | 154 | /** |
48 | 155 | * @brief Setup the cuSolverRf factorization |
49 | 156 | * |
@@ -90,25 +197,26 @@ namespace ReSolve |
90 | 197 | switch (L->getSparseFormat()) |
91 | 198 | { |
92 | 199 | case matrix::Sparse::COMPRESSED_SPARSE_COLUMN: |
93 | | - // std::cout << "converting L and U factors from CSC to CSR format ...\n"; |
94 | 200 | L_csc = static_cast<matrix::Csc*>(L); |
95 | 201 | U_csc = static_cast<matrix::Csc*>(U); |
96 | 202 | L_csr = new matrix::Csr(L_csc->getNumRows(), L_csc->getNumColumns(), L_csc->getNnz()); |
97 | 203 | U_csr = new matrix::Csr(U_csc->getNumRows(), U_csc->getNumColumns(), U_csc->getNnz()); |
98 | 204 | csc2csr(L_csc, L_csr); |
99 | 205 | csc2csr(U_csc, U_csr); |
100 | | - L_csr->syncData(memory::DEVICE); |
101 | | - U_csr->syncData(memory::DEVICE); |
102 | 206 | break; |
103 | 207 | case matrix::Sparse::COMPRESSED_SPARSE_ROW: |
104 | | - L_csr = dynamic_cast<matrix::Csr*>(L); |
105 | | - U_csr = dynamic_cast<matrix::Csr*>(U); |
| 208 | + L_csr = static_cast<matrix::Csr*>(L); |
| 209 | + U_csr = static_cast<matrix::Csr*>(U); |
| 210 | + L_csr->setUpdated(memory::HOST); |
| 211 | + U_csr->setUpdated(memory::HOST); |
106 | 212 | break; |
107 | 213 | default: |
108 | 214 | out::error() << "Matrix type for L and U factors not recognized!\n"; |
109 | 215 | out::error() << "Refactorization not completed.\n"; |
110 | 216 | return 1; |
111 | 217 | } |
| 218 | + L_csr->syncData(memory::DEVICE); |
| 219 | + U_csr->syncData(memory::DEVICE); |
112 | 220 |
|
113 | 221 | if (d_P_ == nullptr) |
114 | 222 | { |
@@ -182,9 +290,6 @@ namespace ReSolve |
182 | 290 | default: |
183 | 291 | break; |
184 | 292 | } |
185 | | - // delete L_csr; |
186 | | - // delete U_csr; |
187 | | - |
188 | 293 | return error_sum; |
189 | 294 | } |
190 | 295 |
|
@@ -296,6 +401,7 @@ namespace ReSolve |
296 | 401 | A_->getNumRows(), |
297 | 402 | x->getData(memory::DEVICE), |
298 | 403 | A_->getNumRows()); |
| 404 | + x->setDataUpdated(memory::DEVICE); |
299 | 405 | return status_cusolverrf_; |
300 | 406 | } |
301 | 407 |
|
|
0 commit comments