@@ -10,20 +10,74 @@ namespace nanoeigenpy {
1010namespace nb = nanobind;
1111using namespace nb ::literals;
1212
13+ template <typename LTypeOrUType, typename MatrixOrVector>
14+ static void solveInPlace (const LTypeOrUType &self,
15+ Eigen::Ref<MatrixOrVector> mat_vec) {
16+ self.solveInPlace (mat_vec);
17+ }
18+
19+ template <typename MappedSupernodalType>
20+ void exposeMatrixL (nb::module_ m) {
21+ using LType = Eigen::SparseLUMatrixLReturnType<MappedSupernodalType>;
22+ using Scalar = typename MappedSupernodalType::Scalar;
23+ using VectorXs = Eigen::Matrix<Scalar, Eigen::Dynamic, 1 , Eigen::ColMajor>;
24+ using MatrixXs =
25+ Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>;
26+
27+ if (check_registration_alias<LType>(m)) {
28+ return ;
29+ }
30+
31+ nb::class_<LType>(m, " SparseLUMatrixLReturnType" )
32+ .def (" rows" , <ype::rows)
33+ .def (" cols" , <ype::cols)
34+ .def (" solveInPlace" , &solveInPlace<LType, MatrixXs>, " X" _a)
35+ .def (" solveInPlace" , &solveInPlace<LType, VectorXs>, " x" _a);
36+ }
37+
38+ template <typename MatrixLType, typename MatrixUType>
39+ void exposeMatrixU (nb::module_ m) {
40+ using UType = Eigen::SparseLUMatrixUReturnType<MatrixLType, MatrixUType>;
41+ using Scalar = typename MatrixLType::Scalar;
42+ using VectorXs = Eigen::Matrix<Scalar, Eigen::Dynamic, 1 , Eigen::ColMajor>;
43+ using MatrixXs =
44+ Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>;
45+
46+ if (check_registration_alias<UType>(m)) {
47+ return ;
48+ }
49+
50+ nb::class_<UType>(m, " SparseLUMatrixUReturnType" )
51+ .def (" rows" , &UType::rows)
52+ .def (" cols" , &UType::cols)
53+ .def (" solveInPlace" , &solveInPlace<UType, MatrixXs>, " X" _a)
54+ .def (" solveInPlace" , &solveInPlace<UType, VectorXs>, " x" _a);
55+ }
56+
1357template <typename _MatrixType, typename _Ordering = Eigen::COLAMDOrdering<
1458 typename _MatrixType::StorageIndex>>
1559void exposeSparseLU (nb::module_ m, const char *name) {
1660 using MatrixType = _MatrixType;
1761 using Solver = Eigen::SparseLU<MatrixType>;
62+ using Scalar = typename MatrixType::Scalar;
1863 using RealScalar = typename MatrixType::RealScalar;
64+ using StorageIndex = typename MatrixType::StorageIndex;
1965 using SparseLUTransposeViewTrue = Eigen::SparseLUTransposeView<true , Solver>;
2066 using SparseLUTransposeViewFalse =
2167 Eigen::SparseLUTransposeView<false , Solver>;
68+ using SCMatrix = typename Solver::SCMatrix;
69+ using MappedSparseMatrix =
70+ typename Eigen::MappedSparseMatrix<Scalar, Eigen::ColMajor, StorageIndex>;
71+ using LType = Eigen::SparseLUMatrixLReturnType<SCMatrix>;
72+ using UType = Eigen::SparseLUMatrixUReturnType<SCMatrix, MappedSparseMatrix>;
2273
2374 if (check_registration_alias<Solver>(m)) {
2475 return ;
2576 }
2677
78+ exposeMatrixL<SCMatrix>(m);
79+ exposeMatrixU<SCMatrix, MappedSparseMatrix>(m);
80+
2781 nb::class_<SparseLUTransposeViewFalse>(m, " SparseLUTransposeView" )
2882 .def (SparseSolverBaseVisitor ())
2983 .def (" setIsInitialized" , &SparseLUTransposeViewFalse::setIsInitialized)
@@ -104,6 +158,13 @@ void exposeSparseLU(nb::module_ m, const char *name) {
104158 },
105159 " Returns an expression of the adjoint of the factored matrix." )
106160
161+ .def (
162+ " matrixL" , [](const Solver &self) -> LType { return self.matrixL (); },
163+ " Returns an expression of the matrix L." )
164+ .def (
165+ " matrixU" , [](const Solver &self) -> UType { return self.matrixU (); },
166+ " Returns an expression of the matrix U." )
167+
107168 .def (" rows" , &Solver::rows, " Returns the number of rows of the matrix." )
108169 .def (" cols" , &Solver::cols, " Returns the number of cols of the matrix." )
109170
0 commit comments