Skip to content

Commit fb4d7c3

Browse files
committed
Check ndim in check_trailing_shape helpers
This was previously checked by using an `array_view<type, ndim>`, but moving to pybind11 we won't have that until cast to `unchecked`.
1 parent 54d718e commit fb4d7c3

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

src/mplutils.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ inline int prepare_and_add_type(PyTypeObject *type, PyObject *module)
7171
template<typename T>
7272
inline bool check_trailing_shape(T array, char const* name, long d1)
7373
{
74+
if (array.ndim() != 2) {
75+
PyErr_Format(PyExc_ValueError,
76+
"Expected 2-dimensional array, got %ld",
77+
array.ndim());
78+
return false;
79+
}
7480
if (array.shape(1) != d1) {
7581
PyErr_Format(PyExc_ValueError,
7682
"%s must have shape (N, %ld), got (%ld, %ld)",
@@ -83,6 +89,12 @@ inline bool check_trailing_shape(T array, char const* name, long d1)
8389
template<typename T>
8490
inline bool check_trailing_shape(T array, char const* name, long d1, long d2)
8591
{
92+
if (array.ndim() != 3) {
93+
PyErr_Format(PyExc_ValueError,
94+
"Expected 3-dimensional array, got %ld",
95+
array.ndim());
96+
return false;
97+
}
8698
if (array.shape(1) != d1 || array.shape(2) != d2) {
8799
PyErr_Format(PyExc_ValueError,
88100
"%s must have shape (N, %ld, %ld), got (%ld, %ld, %ld)",

src/numpy_cpp.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -365,10 +365,6 @@ class array_view : public detail::array_view_accessors<array_view, T, ND>
365365
public:
366366
typedef T value_type;
367367

368-
enum {
369-
ndim = ND
370-
};
371-
372368
array_view() : m_arr(NULL), m_data(NULL)
373369
{
374370
m_shape = zeros;
@@ -492,6 +488,10 @@ class array_view : public detail::array_view_accessors<array_view, T, ND>
492488
return true;
493489
}
494490

491+
npy_intp ndim() const {
492+
return ND;
493+
}
494+
495495
npy_intp shape(size_t i) const
496496
{
497497
if (i >= ND) {

0 commit comments

Comments
 (0)