1+ // / @copyright Copyright (C) 2023-2024 LAAS-CNRS, 2023-2025 INRIA
2+ // / @author Wilson Jallet
13#include " fwd.hpp"
24#include " aligator/core/blk-matrix.hpp"
35
46namespace aligator {
57namespace python {
68namespace bp = boost::python;
79
10+ template <typename T> struct PrintableVisitor ;
11+
812template <typename BlockMatrixType> struct BlkMatrixPythonVisitor ;
913
1014template <typename MatrixType, int N, int M>
1115struct BlkMatrixPythonVisitor <BlkMatrix<MatrixType, N, M>>
1216 : bp::def_visitor<BlkMatrixPythonVisitor<BlkMatrix<MatrixType, N, M>>> {
1317 using BlockMatrixType = BlkMatrix<MatrixType, N, M>;
1418 using RefType = Eigen::Ref<MatrixType>;
19+ static constexpr bool IsVector = BlockMatrixType::IsVectorAtCompileTime;
1520
1621 using Self = BlkMatrixPythonVisitor<BlockMatrixType>;
1722
1823 static RefType get_block (BlockMatrixType &bmt, size_t i, size_t j) {
1924 return bmt (i, j);
2025 }
2126
27+ static RefType get_block2 (BlockMatrixType &bmt, size_t i) { return bmt[i]; }
28+
2229 static RefType blockRow (BlockMatrixType &mat, size_t i) {
2330 if (i >= mat.rowDims ().size ()) {
2431 PyErr_SetString (PyExc_IndexError, " Index out of range." );
@@ -27,19 +34,34 @@ struct BlkMatrixPythonVisitor<BlkMatrix<MatrixType, N, M>>
2734 return mat.blockRow (i);
2835 }
2936
37+ static RefType blockCol (BlockMatrixType &mat, size_t i) {
38+ if (i >= mat.rowDims ().size ()) {
39+ PyErr_SetString (PyExc_IndexError, " Index out of range." );
40+ bp::throw_error_already_set ();
41+ }
42+ return mat.blockCol (i);
43+ }
44+
3045 template <class ... Args> void visit (bp::class_<Args...> &obj) const {
3146 obj.add_property (
3247 " matrix" , +[](BlockMatrixType &m) -> RefType { return m.matrix (); })
33- .def_readonly (" rows" , &BlockMatrixType::rows)
34- .def_readonly (" cols" , &BlockMatrixType::cols)
48+ .add_property (" rows" , &BlockMatrixType::rows)
49+ .add_property (" cols" , &BlockMatrixType::cols)
3550 .add_property (" rowDims" ,
3651 bp::make_function (&BlockMatrixType::rowDims,
3752 bp::return_internal_reference<>()))
3853 .add_property (" colDims" ,
3954 bp::make_function (&BlockMatrixType::colDims,
4055 bp::return_internal_reference<>()))
4156 .def (" blockRow" , blockRow, " Get a block row by index." )
42- .def (" __call__" , get_block, (" self" _a, " i" , " j" ));
57+ .def (" blockCol" , blockCol, " Get a block row by index." )
58+ .def (" setZero" , &BlockMatrixType::setZero, (" self" _a),
59+ " Set all coefficients to zero." )
60+ .def (" __call__" , get_block, (" self" _a, " i" , " j" ))
61+ .def (PrintableVisitor<BlockMatrixType>{});
62+ if constexpr (BlockMatrixType::IsVectorAtCompileTime) {
63+ obj.def (" __call__" , get_block2, (" self" _a, " i" ));
64+ }
4365 }
4466
4567 static void expose (const char *name) {
0 commit comments