Skip to content

Commit 7d8cdf3

Browse files
committed
python : expose more attributes of BlockMatrix class
1 parent 926960f commit 7d8cdf3

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

bindings/python/include/aligator/python/blk-matrix.hpp

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,31 @@
1+
/// @copyright Copyright (C) 2023-2024 LAAS-CNRS, 2023-2025 INRIA
2+
/// @author Wilson Jallet
13
#include "fwd.hpp"
24
#include "aligator/core/blk-matrix.hpp"
35

46
namespace aligator {
57
namespace python {
68
namespace bp = boost::python;
79

10+
template <typename T> struct PrintableVisitor;
11+
812
template <typename BlockMatrixType> struct BlkMatrixPythonVisitor;
913

1014
template <typename MatrixType, int N, int M>
1115
struct BlkMatrixPythonVisitor<BlkMatrix<MatrixType, N, M>>
1216
: bp::def_visitor<BlkMatrixPythonVisitor<BlkMatrix<MatrixType, N, M>>> {
1317
using BlockMatrixType = BlkMatrix<MatrixType, N, M>;
1418
using RefType = Eigen::Ref<MatrixType>;
19+
static constexpr bool IsVector = BlockMatrixType::IsVectorAtCompileTime;
1520

1621
using Self = BlkMatrixPythonVisitor<BlockMatrixType>;
1722

1823
static RefType get_block(BlockMatrixType &bmt, size_t i, size_t j) {
1924
return bmt(i, j);
2025
}
2126

27+
static RefType get_block2(BlockMatrixType &bmt, size_t i) { return bmt[i]; }
28+
2229
static RefType blockRow(BlockMatrixType &mat, size_t i) {
2330
if (i >= mat.rowDims().size()) {
2431
PyErr_SetString(PyExc_IndexError, "Index out of range.");
@@ -27,19 +34,34 @@ struct BlkMatrixPythonVisitor<BlkMatrix<MatrixType, N, M>>
2734
return mat.blockRow(i);
2835
}
2936

37+
static RefType blockCol(BlockMatrixType &mat, size_t i) {
38+
if (i >= mat.rowDims().size()) {
39+
PyErr_SetString(PyExc_IndexError, "Index out of range.");
40+
bp::throw_error_already_set();
41+
}
42+
return mat.blockCol(i);
43+
}
44+
3045
template <class... Args> void visit(bp::class_<Args...> &obj) const {
3146
obj.add_property(
3247
"matrix", +[](BlockMatrixType &m) -> RefType { return m.matrix(); })
33-
.def_readonly("rows", &BlockMatrixType::rows)
34-
.def_readonly("cols", &BlockMatrixType::cols)
48+
.add_property("rows", &BlockMatrixType::rows)
49+
.add_property("cols", &BlockMatrixType::cols)
3550
.add_property("rowDims",
3651
bp::make_function(&BlockMatrixType::rowDims,
3752
bp::return_internal_reference<>()))
3853
.add_property("colDims",
3954
bp::make_function(&BlockMatrixType::colDims,
4055
bp::return_internal_reference<>()))
4156
.def("blockRow", blockRow, "Get a block row by index.")
42-
.def("__call__", get_block, ("self"_a, "i", "j"));
57+
.def("blockCol", blockCol, "Get a block row by index.")
58+
.def("setZero", &BlockMatrixType::setZero, ("self"_a),
59+
"Set all coefficients to zero.")
60+
.def("__call__", get_block, ("self"_a, "i", "j"))
61+
.def(PrintableVisitor<BlockMatrixType>{});
62+
if constexpr (BlockMatrixType::IsVectorAtCompileTime) {
63+
obj.def("__call__", get_block2, ("self"_a, "i"));
64+
}
4365
}
4466

4567
static void expose(const char *name) {

0 commit comments

Comments
 (0)