Skip to content

Commit 5029c05

Browse files
authored
Add validation for numpy array as return type in AddKToMatrix bindings (sofa-framework#541)
* Add validation for numpy array as return type in AddKToMatrix bindings * Refactor AddKToMatrix return type validation and simplify conditional checks
1 parent 8ba1e2c commit 5029c05

File tree

1 file changed

+24
-21
lines changed

1 file changed

+24
-21
lines changed

bindings/Sofa/src/SofaPython3/Sofa/Core/Binding_ForceField.cpp

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -147,36 +147,39 @@ namespace sofapython3
147147

148148
py::object ret = _addKToMatrix(mparams, nNodes, nDofs);
149149

150+
if(!py::isinstance<py::array>(ret))
151+
{
152+
throw py::type_error("Can't read return value of AddKToMatrix. A numpy array is expected");
153+
}
154+
150155
// if ret is numpy array
151-
if(py::isinstance<py::array>(ret))
156+
auto r = py::cast<py::array>(ret);
157+
if (r.ndim() == 3 && r.shape(2) == 1)
152158
{
153-
auto r = py::cast<py::array>(ret);
154-
if (r.ndim() == 3 && r.shape(2) == 1)
159+
// read K as a plain 2D matrix
160+
auto kMatrix = r.unchecked<double, 3>();
161+
for (size_t x = 0 ; x < size_t(kMatrix.shape(0)) ; ++x)
155162
{
156-
// read K as a plain 2D matrix
157-
auto kMatrix = r.unchecked<double, 3>();
158-
for (size_t x = 0 ; x < size_t(kMatrix.shape(0)) ; ++x)
163+
for (size_t y = 0 ; y < size_t(kMatrix.shape(1)) ; ++y)
159164
{
160-
for (size_t y = 0 ; y < size_t(kMatrix.shape(1)) ; ++y)
161-
{
162-
mat->add(int(offset + x), int(offset + y), kMatrix(x,y, 0));
163-
}
165+
mat->add(int(offset + x), int(offset + y), kMatrix(x,y, 0));
164166
}
165167
}
166-
else if (r.ndim() == 2 && r.shape(1) == 3)
167-
{
168-
// consider ret to be a list of tuples [(i,j,[val])]
169-
auto kMatrix = r.unchecked<double, 2>();
170-
for (auto x = 0 ; x < kMatrix.shape(0) ; ++x)
171-
{
172-
mat->add(int(offset + size_t(kMatrix(x,0))), int(offset + size_t(kMatrix(x,1))), kMatrix(x,2));
173-
}
174-
}
175-
else
168+
}
169+
else if (r.ndim() == 2 && r.shape(1) == 3)
170+
{
171+
// consider ret to be a list of tuples [(i,j,[val])]
172+
auto kMatrix = r.unchecked<double, 2>();
173+
for (auto x = 0 ; x < kMatrix.shape(0) ; ++x)
176174
{
177-
throw py::type_error("Can't read return value of AddKToMatrix. The method should return either a plain 2D matrix or a vector of tuples (i, j, val)");
175+
mat->add(int(offset + size_t(kMatrix(x,0))), int(offset + size_t(kMatrix(x,1))), kMatrix(x,2));
178176
}
179177
}
178+
else
179+
{
180+
throw py::type_error("Can't read return value of AddKToMatrix. The method should return either a plain 2D matrix or a vector of tuples (i, j, val)");
181+
}
182+
180183
}
181184

182185

0 commit comments

Comments
 (0)