Skip to content

Commit 885edfa

Browse files
authored
Enforce F being 0 and remove duplicate extract factors code
* added check for F!=0 * enforce F being 0 and clean up factor extraction code duplication * Apply pre-commmit fixes * Update resolve/LinSolverDirectKLU.cpp * cleaned up comments * Apply pre-commmit fixes * Update resolve/LinSolverDirectKLU.cpp --------- Co-authored-by: shakedregev <shakedregev@users.noreply.github.com>
1 parent 0edd57e commit 885edfa

File tree

2 files changed

+42
-97
lines changed

2 files changed

+42
-97
lines changed

resolve/LinSolverDirectKLU.cpp

Lines changed: 40 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "LinSolverDirectKLU.hpp"
22

3-
#include <cstring> // includes memcpy
3+
#include <cassert>
4+
#include <cstring>
45

56
#include <resolve/matrix/Csc.hpp>
67
#include <resolve/matrix/Csr.hpp>
@@ -166,6 +167,13 @@ namespace ReSolve
166167
{
167168
return 1;
168169
}
170+
else
171+
{
172+
if (Numeric_->nzoff != 0)
173+
{
174+
assert(0 && "Numeric_->nzoff != 0, this is not supported by ReSolve!");
175+
}
176+
}
169177
return 0;
170178
}
171179

@@ -253,22 +261,12 @@ namespace ReSolve
253261
}
254262

255263
/**
256-
* @brief Get the L factor of the matrix A in compressed sparse row format.
257-
*
258-
* This function extracts the lower triangular factor L from the
259-
* KLU solver's numeric factorization. If the factors have not been
260-
* extracted yet, it allocates memory for L and U factors,
261-
* extracts them from the numeric factorization,
262-
* and sets the updated flag for both factors.
263-
* Otherwise, it returns the already extracted L factor.
264-
* This is because the input matrix is CSR and interpreted as CSC.
265-
* Then the CSC U extraced from klu is actually an L factor of the original matrix,
266-
* if interprted as CSR. The reverse is true for the CSC L being a U factor of the original matrix.
267-
* Note that in this factorization, the scaling is in the L factor, unlike convention.
264+
* @brief Extracts L and U factors from the KLU solver in CSR format, if they have not already been extracted.
268265
*
269-
* @return L factor in compressed sparse row format
266+
* It extracts the factors as $A = U^T L^T$,
267+
* where U^T is the reinterpretation of the CSC U factor as CSR and L^T is the reinterpretation of the CSC L factor as CSR.
270268
*/
271-
matrix::Sparse* LinSolverDirectKLU::getLFactorCsr()
269+
void LinSolverDirectKLU::extractFactorsCsr()
272270
{
273271
if (!factors_extracted_)
274272
{
@@ -303,69 +301,35 @@ namespace ReSolve
303301
(void) ok; // TODO: Check status in ok before setting `factors_extracted_`
304302
factors_extracted_ = true;
305303
}
304+
return;
305+
}
306+
307+
/**
308+
* @brief Gets an L factor of the matrix A in compressed sparse row format.
309+
*
310+
* @return L factor in compressed sparse row format
311+
*/
312+
matrix::Sparse* LinSolverDirectKLU::getLFactorCsr()
313+
{
314+
extractFactorsCsr();
306315
return L_;
307316
}
308317

309318
/**
310-
* @brief Get the U factor of the matrix A in compressed sparse row format.
311-
*
312-
* This function extracts the upper triangular factor U from the
313-
* KLU solver's numeric factorization. If the factors have not been
314-
* extracted yet, it allocates memory for L and U factors,
315-
* extracts them from the numeric factorization,
316-
* and sets the updated flag for both factors.
317-
* Otherwise, it returns the already extracted U factor.
318-
* This is because the input matrix is CSR and interpreted as CSC.
319-
* Then the CSC U extraced from klu is actually an L factor of the original matrix,
320-
* if interprted as CSR. The reverse is true for the CSC L being a U factor of the original matrix.
321-
* Note that in this factorization, the scaling is in the L factor, unlike convention.
319+
* @brief Gets a U factor of the matrix A in compressed sparse row format.
322320
*
323321
* @return U factor in compressed sparse row format
324322
*/
325323
matrix::Sparse* LinSolverDirectKLU::getUFactorCsr()
326324
{
327-
if (!factors_extracted_)
328-
{
329-
const int nnzL = Numeric_->lnz;
330-
const int nnzU = Numeric_->unz;
331-
332-
// Create CSR matrices - L gets U's data, U gets L's data
333-
L_ = new matrix::Csr(A_->getNumRows(), A_->getNumColumns(), nnzU);
334-
U_ = new matrix::Csr(A_->getNumRows(), A_->getNumColumns(), nnzL);
335-
L_->allocateMatrixData(memory::HOST);
336-
U_->allocateMatrixData(memory::HOST);
337-
338-
int ok = klu_extract(Numeric_,
339-
Symbolic_,
340-
U_->getRowData(memory::HOST), // L CSC colptr -> U CSR rowptr
341-
U_->getColData(memory::HOST), // L CSC rowidx -> U CSR colidx
342-
U_->getValues(memory::HOST), // L CSC values -> U CSR values
343-
L_->getRowData(memory::HOST), // U CSC colptr -> L CSR rowptr
344-
L_->getColData(memory::HOST), // U CSC rowidx -> L CSR colidx
345-
L_->getValues(memory::HOST), // U CSC values -> L CSR values
346-
nullptr,
347-
nullptr,
348-
nullptr,
349-
nullptr,
350-
nullptr,
351-
nullptr,
352-
nullptr,
353-
&Common_);
354-
355-
L_->setUpdated(memory::HOST);
356-
U_->setUpdated(memory::HOST);
357-
(void) ok; // TODO: Check status in ok before setting `factors_extracted_`
358-
factors_extracted_ = true;
359-
}
325+
extractFactorsCsr();
360326
return U_;
361327
}
362328

363329
/**
364-
* @brief Get the lower triangular factor L.
365-
*
366-
* @return L factor
330+
* @brief Extract L and U factors from the KLU solver in compressed sparse column format.
367331
*/
368-
matrix::Sparse* LinSolverDirectKLU::getLFactor()
332+
void LinSolverDirectKLU::extractFactors()
369333
{
370334
if (!factors_extracted_)
371335
{
@@ -399,48 +363,27 @@ namespace ReSolve
399363
(void) ok; // TODO: Check status in ok before setting `factors_extracted_`
400364
factors_extracted_ = true;
401365
}
366+
}
367+
368+
/**
369+
* @brief Get the lower triangular factor L of the matrix A in compressed sparse column format.
370+
*
371+
* @return L factor in compressed sparse column format
372+
*/
373+
matrix::Sparse* LinSolverDirectKLU::getLFactor()
374+
{
375+
extractFactors();
402376
return L_;
403377
}
404378

405379
/**
406-
* @brief Get the upper triangular factor U.
380+
* @brief Get the upper triangular factor U of the matrix A in compressed sparse column format.
407381
*
408382
* @return U factor
409383
*/
410384
matrix::Sparse* LinSolverDirectKLU::getUFactor()
411385
{
412-
if (!factors_extracted_)
413-
{
414-
const int nnzL = Numeric_->lnz;
415-
const int nnzU = Numeric_->unz;
416-
417-
L_ = new matrix::Csc(A_->getNumRows(), A_->getNumColumns(), nnzL);
418-
U_ = new matrix::Csc(A_->getNumRows(), A_->getNumColumns(), nnzU);
419-
L_->allocateMatrixData(memory::HOST);
420-
U_->allocateMatrixData(memory::HOST);
421-
int ok = klu_extract(Numeric_,
422-
Symbolic_,
423-
L_->getColData(memory::HOST),
424-
L_->getRowData(memory::HOST),
425-
L_->getValues(memory::HOST),
426-
U_->getColData(memory::HOST),
427-
U_->getRowData(memory::HOST),
428-
U_->getValues(memory::HOST),
429-
nullptr,
430-
nullptr,
431-
nullptr,
432-
nullptr,
433-
nullptr,
434-
nullptr,
435-
nullptr,
436-
&Common_);
437-
438-
L_->setUpdated(memory::HOST);
439-
U_->setUpdated(memory::HOST);
440-
441-
(void) ok; // TODO: Check status in ok before setting `factors_extracted_`
442-
factors_extracted_ = true;
443-
}
386+
extractFactors();
444387
return U_;
445388
}
446389

resolve/LinSolverDirectKLU.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ namespace ReSolve
3939
int solve(vector_type* rhs, vector_type* x) override;
4040
int solve(vector_type* x) override;
4141

42+
void extractFactorsCsr();
4243
matrix::Sparse* getLFactorCsr() override;
4344
matrix::Sparse* getUFactorCsr() override;
45+
void extractFactors();
4446
matrix::Sparse* getLFactor() override;
4547
matrix::Sparse* getUFactor() override;
4648
index_type* getPOrdering() override;

0 commit comments

Comments
 (0)