Skip to content

Commit 0594c97

Browse files
committed
decompositions: SparseLU and SparseQR: Tried to expose matrixU, matrixL, matrixQ without success yet
1 parent d16d664 commit 0594c97

File tree

2 files changed

+154
-8
lines changed

2 files changed

+154
-8
lines changed

include/nanoeigenpy/decompositions/sparse/sparse-lu.hpp

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,42 @@ void exposeSparseLU(nb::module_ m, const char *name) {
1616
using MatrixType = _MatrixType;
1717
using Solver = Eigen::SparseLU<MatrixType>;
1818
using RealScalar = typename MatrixType::RealScalar;
19+
using SparseLUTransposeViewTrue = Eigen::SparseLUTransposeView<true, Solver>;
20+
using SparseLUTransposeViewFalse =
21+
Eigen::SparseLUTransposeView<false, Solver>;
22+
// using SCMatrix = typename Solver::SCMatrix;
23+
// using MappedSparseMatrixType = Eigen::MappedSparseMatrix<typename
24+
// Solver::Scalar, Eigen::ColMajor, typename Solver::StorageIndex>; using
25+
// SparseLUMatrixLType = Eigen::SparseLUMatrixLReturnType<SCMatrix>; using
26+
// SparseLUMatrixUType = Eigen::SparseLUMatrixUReturnType<SCMatrix,
27+
// MappedSparseMatrixType>;
1928

2029
if (check_registration_alias<Solver>(m)) {
2130
return;
2231
}
32+
33+
nb::class_<SparseLUTransposeViewFalse>(m, "SparseLUTransposeView")
34+
.def(SparseSolverBaseVisitor())
35+
.def("setIsInitialized", &SparseLUTransposeViewFalse::setIsInitialized)
36+
.def("setSparseLU", &SparseLUTransposeViewFalse::setSparseLU)
37+
.def("rows", &SparseLUTransposeViewFalse::rows)
38+
.def("cols", &SparseLUTransposeViewFalse::cols);
39+
40+
nb::class_<SparseLUTransposeViewTrue>(m, "SparseLUAdjointView")
41+
.def(SparseSolverBaseVisitor())
42+
.def("setIsInitialized", &SparseLUTransposeViewTrue::setIsInitialized)
43+
.def("setSparseLU", &SparseLUTransposeViewTrue::setSparseLU)
44+
.def("rows", &SparseLUTransposeViewTrue::rows)
45+
.def("cols", &SparseLUTransposeViewTrue::cols);
46+
47+
// nb::class_<SparseLUMatrixLType>(m, "SparseLUMatrixL")
48+
// .def("rows", &SparseLUMatrixLType::rows)
49+
// .def("cols", &SparseLUMatrixLType::cols);
50+
51+
// nb::class_<SparseLUMatrixUType>(m, "SparseLUMatrixU")
52+
// .def("rows", &SparseLUMatrixUType::rows)
53+
// .def("cols", &SparseLUMatrixUType::cols);
54+
2355
nb::class_<Solver>(
2456
m, name,
2557
"Sparse supernodal LU factorization for general matrices.\n\n"
@@ -70,15 +102,42 @@ void exposeSparseLU(nb::module_ m, const char *name) {
70102
"matrix.\n\n"
71103
"The input matrix should be in column-major storage.")
72104

73-
// TODO: Expose so that the return type are convertible to np arrays
74-
// transpose()
75-
// adjoint()
76-
// matrixU()
77-
// matrixL()
105+
.def(
106+
"transpose",
107+
[](Solver &self) -> SparseLUTransposeViewFalse {
108+
auto view = self.transpose();
109+
return view;
110+
},
111+
"Returns an expression of the transposed of the factored matrix.")
112+
113+
.def(
114+
"adjoint",
115+
[](Solver &self) -> SparseLUTransposeViewTrue {
116+
auto view = self.adjoint();
117+
return view;
118+
},
119+
"Returns an expression of the adjoint of the factored matrix.")
78120

79121
.def("rows", &Solver::rows, "Returns the number of rows of the matrix.")
80122
.def("cols", &Solver::cols, "Returns the number of cols of the matrix.")
81123

124+
.def("isSymmetric", &Solver::isSymmetric,
125+
"Indicate that the pattern of the input matrix is symmetric.")
126+
127+
// .def("matrixU",
128+
// [](Solver& self) -> SparseLUMatrixUType {
129+
// auto view = self.matrixU();
130+
// return view;
131+
// },
132+
// "Returns an expression of the matrix U.")
133+
134+
// .def("matrixL",
135+
// [](Solver& self) -> SparseLUMatrixLType {
136+
// auto view = self.matrixL();
137+
// return view;
138+
// },
139+
// "Returns an expression of the matrix L.")
140+
82141
.def("rowsPermutation", &Solver::rowsPermutation,
83142
"Returns a reference to the row matrix permutation "
84143
"\f$ P_r \f$ such that \f$P_r A P_c^T = L U\f$.",

include/nanoeigenpy/decompositions/sparse/sparse-qr.hpp

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,97 @@ namespace nanoeigenpy {
1010
namespace nb = nanobind;
1111
using namespace nb::literals;
1212

13+
// template<typename SparseQRType>
14+
// class SparseQRMatrixQReturnTypeWrapper {
15+
// private:
16+
// Eigen::SparseQRMatrixQReturnType<SparseQRType> m_q_expr;
17+
18+
// public:
19+
// explicit SparseQRMatrixQReturnTypeWrapper(const SparseQRType& qr)
20+
// : m_q_expr(qr) {}
21+
22+
// Eigen::Index rows() const { return m_q_expr.rows(); }
23+
// Eigen::Index cols() const { return m_q_expr.cols(); }
24+
25+
// Eigen::VectorXd multiply_vec(const Eigen::VectorXd& vec) {
26+
// return Eigen::VectorXd(m_q_expr * vec);
27+
// }
28+
29+
// Eigen::MatrixXd multiply_mat(const Eigen::MatrixXd& mat) {
30+
// return Eigen::MatrixXd(m_q_expr * mat);
31+
// }
32+
33+
// auto adjoint() const {
34+
// return
35+
// SparseQRMatrixQTransposeReturnTypeWrapper<SparseQRType>(m_q_expr);
36+
// }
37+
38+
// auto transpose() const {
39+
// return
40+
// SparseQRMatrixQTransposeReturnTypeWrapper<SparseQRType>(m_q_expr);
41+
// }
42+
// };
43+
44+
// template<typename SparseQRType>
45+
// class SparseQRMatrixQTransposeReturnTypeWrapper {
46+
// private:
47+
// Eigen::SparseQRMatrixQTransposeReturnType<SparseQRType> m_qt_expr;
48+
49+
// public:
50+
// explicit SparseQRMatrixQTransposeReturnTypeWrapper(const SparseQRType&
51+
// qt)
52+
// : m_qt_expr(qt) {}
53+
54+
// Eigen::VectorXd multiply_vec(const Eigen::VectorXd& vec) {
55+
// return Eigen::VectorXd(m_qt_expr * vec);
56+
// }
57+
58+
// Eigen::MatrixXd multiply_mat(const Eigen::MatrixXd& mat) {
59+
// return Eigen::MatrixXd(m_qt_expr * mat);
60+
// }
61+
// };
62+
1363
template <typename _MatrixType, typename _Ordering = Eigen::COLAMDOrdering<
1464
typename _MatrixType::StorageIndex>>
1565
void exposeSparseQR(nb::module_ m, const char *name) {
1666
using MatrixType = _MatrixType;
1767
using Ordering = _Ordering;
1868
using Solver = Eigen::SparseQR<MatrixType, Ordering>;
69+
using Scalar = typename MatrixType::Scalar;
1970
using RealScalar = typename MatrixType::RealScalar;
71+
using QRMatrixType = Eigen::SparseMatrix<Scalar, Eigen::ColMajor,
72+
typename MatrixType::StorageIndex>;
73+
// using QWrapper = SparseQRMatrixQReturnTypeWrapper<Solver>;
74+
// using QTWrapper = SparseQRMatrixQTransposeReturnTypeWrapper<Solver>;
2075

2176
if (check_registration_alias<Solver>(m)) {
2277
return;
2378
}
79+
80+
// nb::class_<QWrapper>(m, "SparseQRMatrixQ")
81+
// .def("rows", &QWrapper::rows)
82+
// .def("cols", &QWrapper::cols)
83+
// .def("__mul__", [](const QWrapper& self, const Eigen::Ref<const
84+
// Eigen::VectorXd>& vec) {
85+
// return self.multiply_vec(vec);
86+
// }, "vec"_a)
87+
// .def("__mul__", [](const QWrapper& self, const Eigen::Ref<const
88+
// Eigen::MatrixXd>& mat) {
89+
// return self.multiply_mat(mat);
90+
// }, "mat"_a)
91+
// .def("adjoint", &QWrapper::adjoint)
92+
// .def("transpose", &QWrapper::transpose);
93+
94+
// nb::class_<QTWrapper>(m, "SparseQRMatrixQTranspose")
95+
// .def("__mul__", [](const QTWrapper& self, const Eigen::Ref<const
96+
// Eigen::VectorXd>& vec) {
97+
// return self.multiply_vec(vec);
98+
// }, "vec"_a)
99+
// .def("__mul__", [](const QTWrapper& self, const Eigen::Ref<const
100+
// Eigen::MatrixXd>& mat) {
101+
// return self.multiply_mat(mat);
102+
// }, "mat"_a);
103+
24104
nb::class_<Solver>(
25105
m, name,
26106
"Sparse left-looking QR factorization with numerical column pivoting. "
@@ -67,9 +147,16 @@ void exposeSparseQR(nb::module_ m, const char *name) {
67147
"The input matrix should be in compressed mode "
68148
"(see SparseMatrix::makeCompressed()).")
69149

70-
// TODO: Expose so that the return type are convertible to np arrays
71-
// matrixQ()
72-
// matrixR()
150+
// .def("matrixQ", [](const Solver& self) -> QWrapper {
151+
// return QWrapper(self);
152+
// }, "Returns an expression of the matrix Q")
153+
154+
.def(
155+
"matrixR",
156+
[](Solver &self) -> const QRMatrixType & { return self.matrixR(); },
157+
"Returns a const reference to the \b sparse upper triangular matrix "
158+
"R of the QR factorization.",
159+
nb::rv_policy::reference_internal)
73160

74161
.def("rows", &Solver::rows, "Returns the number of rows of the matrix.")
75162
.def("cols", &Solver::cols, "Returns the number of cols of the matrix.")

0 commit comments

Comments
 (0)