@@ -10,21 +10,82 @@ namespace nanoeigenpy {
1010namespace nb = nanobind;
1111using namespace nb ::literals;
1212
13+ template <typename SparseQRType>
14+ void exposeMatrixQ (nb::module_ m) {
15+ using Scalar = typename SparseQRType::Scalar;
16+ using QType = Eigen::SparseQRMatrixQReturnType<SparseQRType>;
17+ using QTransposeType =
18+ Eigen::SparseQRMatrixQTransposeReturnType<SparseQRType>;
19+ using VectorXd = Eigen::VectorXd;
20+ using MatrixXd = Eigen::MatrixXd;
21+ using QRMatrixType = typename SparseQRType::QRMatrixType;
22+
23+ if (!check_registration_alias<QTransposeType>(m)) {
24+ nb::class_<QTransposeType>(m, " SparseQRMatrixQTransposeReturnType" )
25+ .def (nb::init<const SparseQRType&>(), " qr" _a)
26+
27+ .def (
28+ " __matmul__" ,
29+ [](QTransposeType& self, const MatrixXd& other) -> MatrixXd {
30+ return MatrixXd (self * other);
31+ },
32+ " other" _a)
33+
34+ .def (
35+ " __matmul__" ,
36+ [](QTransposeType& self, const VectorXd& other) -> VectorXd {
37+ return VectorXd (self * other);
38+ },
39+ " other" _a);
40+ }
41+
42+ if (!check_registration_alias<QType>(m)) {
43+ nb::class_<QType>(m, " SparseQRMatrixQReturnType" )
44+ .def (nb::init<const SparseQRType&>(), " qr" _a)
45+
46+ .def_prop_ro (" rows" , &QType::rows)
47+ .def_prop_ro (" cols" , &QType::cols)
48+
49+ .def (
50+ " __matmul__" ,
51+ [](QType& self, const MatrixXd& other) -> MatrixXd {
52+ return MatrixXd (self * other);
53+ },
54+ " other" _a)
55+
56+ .def (
57+ " __matmul__" ,
58+ [](QType& self, const VectorXd& other) -> VectorXd {
59+ return VectorXd (self * other);
60+ },
61+ " other" _a)
62+
63+ .def (" adjoint" ,
64+ [](const QType& self) -> QTransposeType { return self.adjoint (); })
65+
66+ .def (" transpose" , [](const QType& self) -> QTransposeType {
67+ return self.transpose ();
68+ });
69+ }
70+ }
71+
1372template <typename _MatrixType, typename _Ordering = Eigen::COLAMDOrdering<
1473 typename _MatrixType::StorageIndex>>
15- void exposeSparseQR (nb::module_ m, const char * name) {
74+ void exposeSparseQR (nb::module_ m, const char * name) {
1675 using MatrixType = _MatrixType;
1776 using Ordering = _Ordering;
1877 using Solver = Eigen::SparseQR<MatrixType, Ordering>;
1978 using Scalar = typename MatrixType::Scalar;
2079 using RealScalar = typename MatrixType::RealScalar;
21- using QRMatrixType = Eigen::SparseMatrix<Scalar, Eigen::ColMajor,
22- typename MatrixType::StorageIndex >;
80+ using QRMatrixType = typename Solver::QRMatrixType;
81+ using QType = Eigen::SparseQRMatrixQReturnType<Solver >;
2382
2483 if (check_registration_alias<Solver>(m)) {
2584 return ;
2685 }
2786
87+ exposeMatrixQ<Solver>(m);
88+
2889 nb::class_<Solver>(
2990 m, name,
3091 " Sparse left-looking QR factorization with numerical column pivoting. "
@@ -51,7 +112,7 @@ void exposeSparseQR(nb::module_ m, const char *name) {
51112 " factor of full rank." )
52113
53114 .def (nb::init<>(), " Default constructor." )
54- .def (nb::init<const MatrixType &>(), " matrix" _a,
115+ .def (nb::init<const MatrixType&>(), " matrix" _a,
55116 " Constructs a LU factorization from a given matrix." )
56117
57118 .def (SparseSolverBaseVisitor ())
@@ -71,9 +132,13 @@ void exposeSparseQR(nb::module_ m, const char *name) {
71132 " The input matrix should be in compressed mode "
72133 " (see SparseMatrix::makeCompressed())." )
73134
135+ .def (
136+ " matrixQ" , [](const Solver& self) -> QType { return self.matrixQ (); },
137+ " Returns an expression of the matrix Q as products of sparse "
138+ " Householder reflectors." )
74139 .def (
75140 " matrixR" ,
76- [](const Solver & self) -> const QRMatrixType & {
141+ [](const Solver& self) -> const QRMatrixType& {
77142 return self.matrixR ();
78143 },
79144 " Returns a const reference to the \b sparse upper triangular matrix "
@@ -99,7 +164,7 @@ void exposeSparseQR(nb::module_ m, const char *name) {
99164
100165 .def (
101166 " setPivotThreshold" ,
102- [](Solver & self, const RealScalar & thresh) -> void {
167+ [](Solver& self, const RealScalar& thresh) -> void {
103168 return self.setPivotThreshold (thresh);
104169 },
105170 " Set the threshold used for a diagonal entry to be an acceptable "
0 commit comments