@@ -10,17 +10,97 @@ namespace nanoeigenpy {
1010namespace nb = nanobind;
1111using namespace nb ::literals;
1212
13+ // template<typename SparseQRType>
14+ // class SparseQRMatrixQReturnTypeWrapper {
15+ // private:
16+ // Eigen::SparseQRMatrixQReturnType<SparseQRType> m_q_expr;
17+
18+ // public:
19+ // explicit SparseQRMatrixQReturnTypeWrapper(const SparseQRType& qr)
20+ // : m_q_expr(qr) {}
21+
22+ // Eigen::Index rows() const { return m_q_expr.rows(); }
23+ // Eigen::Index cols() const { return m_q_expr.cols(); }
24+
25+ // Eigen::VectorXd multiply_vec(const Eigen::VectorXd& vec) {
26+ // return Eigen::VectorXd(m_q_expr * vec);
27+ // }
28+
29+ // Eigen::MatrixXd multiply_mat(const Eigen::MatrixXd& mat) {
30+ // return Eigen::MatrixXd(m_q_expr * mat);
31+ // }
32+
33+ // auto adjoint() const {
34+ // return
35+ // SparseQRMatrixQTransposeReturnTypeWrapper<SparseQRType>(m_q_expr);
36+ // }
37+
38+ // auto transpose() const {
39+ // return
40+ // SparseQRMatrixQTransposeReturnTypeWrapper<SparseQRType>(m_q_expr);
41+ // }
42+ // };
43+
44+ // template<typename SparseQRType>
45+ // class SparseQRMatrixQTransposeReturnTypeWrapper {
46+ // private:
47+ // Eigen::SparseQRMatrixQTransposeReturnType<SparseQRType> m_qt_expr;
48+
49+ // public:
50+ // explicit SparseQRMatrixQTransposeReturnTypeWrapper(const SparseQRType&
51+ // qt)
52+ // : m_qt_expr(qt) {}
53+
54+ // Eigen::VectorXd multiply_vec(const Eigen::VectorXd& vec) {
55+ // return Eigen::VectorXd(m_qt_expr * vec);
56+ // }
57+
58+ // Eigen::MatrixXd multiply_mat(const Eigen::MatrixXd& mat) {
59+ // return Eigen::MatrixXd(m_qt_expr * mat);
60+ // }
61+ // };
62+
1363template <typename _MatrixType, typename _Ordering = Eigen::COLAMDOrdering<
1464 typename _MatrixType::StorageIndex>>
1565void exposeSparseQR (nb::module_ m, const char *name) {
1666 using MatrixType = _MatrixType;
1767 using Ordering = _Ordering;
1868 using Solver = Eigen::SparseQR<MatrixType, Ordering>;
69+ using Scalar = typename MatrixType::Scalar;
1970 using RealScalar = typename MatrixType::RealScalar;
71+ using QRMatrixType = Eigen::SparseMatrix<Scalar, Eigen::ColMajor,
72+ typename MatrixType::StorageIndex>;
73+ // using QWrapper = SparseQRMatrixQReturnTypeWrapper<Solver>;
74+ // using QTWrapper = SparseQRMatrixQTransposeReturnTypeWrapper<Solver>;
2075
2176 if (check_registration_alias<Solver>(m)) {
2277 return ;
2378 }
79+
80+ // nb::class_<QWrapper>(m, "SparseQRMatrixQ")
81+ // .def("rows", &QWrapper::rows)
82+ // .def("cols", &QWrapper::cols)
83+ // .def("__mul__", [](const QWrapper& self, const Eigen::Ref<const
84+ // Eigen::VectorXd>& vec) {
85+ // return self.multiply_vec(vec);
86+ // }, "vec"_a)
87+ // .def("__mul__", [](const QWrapper& self, const Eigen::Ref<const
88+ // Eigen::MatrixXd>& mat) {
89+ // return self.multiply_mat(mat);
90+ // }, "mat"_a)
91+ // .def("adjoint", &QWrapper::adjoint)
92+ // .def("transpose", &QWrapper::transpose);
93+
94+ // nb::class_<QTWrapper>(m, "SparseQRMatrixQTranspose")
95+ // .def("__mul__", [](const QTWrapper& self, const Eigen::Ref<const
96+ // Eigen::VectorXd>& vec) {
97+ // return self.multiply_vec(vec);
98+ // }, "vec"_a)
99+ // .def("__mul__", [](const QTWrapper& self, const Eigen::Ref<const
100+ // Eigen::MatrixXd>& mat) {
101+ // return self.multiply_mat(mat);
102+ // }, "mat"_a);
103+
24104 nb::class_<Solver>(
25105 m, name,
26106 " Sparse left-looking QR factorization with numerical column pivoting. "
@@ -67,9 +147,16 @@ void exposeSparseQR(nb::module_ m, const char *name) {
67147 " The input matrix should be in compressed mode "
68148 " (see SparseMatrix::makeCompressed())." )
69149
70- // TODO: Expose so that the return type are convertible to np arrays
71- // matrixQ()
72- // matrixR()
150+ // .def("matrixQ", [](const Solver& self) -> QWrapper {
151+ // return QWrapper(self);
152+ // }, "Returns an expression of the matrix Q")
153+
154+ .def (
155+ " matrixR" ,
156+ [](Solver &self) -> const QRMatrixType & { return self.matrixR (); },
157+ " Returns a const reference to the \b sparse upper triangular matrix "
158+ " R of the QR factorization." ,
159+ nb::rv_policy::reference_internal)
73160
74161 .def (" rows" , &Solver::rows, " Returns the number of rows of the matrix." )
75162 .def (" cols" , &Solver::cols, " Returns the number of cols of the matrix." )
0 commit comments