Skip to content

Commit bc041de

Browse files
authored
Fix buffer protocol implementation (#5407)
* Fix buffer protocol implementation According to the buffer protocol, `ndim` is a _required_ field [1], and should always be set correctly. Additionally, `shape` should be set if flags includes `PyBUF_ND` or higher [2]. The current implementation only set those fields if flags was `PyBUF_STRIDES`. [1] https://docs.python.org/3/c-api/buffer.html#request-independent-fields [2] https://docs.python.org/3/c-api/buffer.html#shape-strides-suboffsets * Apply suggestions from review * Obey contiguity requests for buffer protocol If a contiguous buffer is requested, and the underlying buffer isn't, then that should raise. This matches NumPy behaviour if you do something like: ``` struct.unpack_from('5d', np.arange(20.0)[::4]) # Raises for contiguity ``` Also, if a buffer is contiguous, then it can masquerade as a less-complex buffer, either by dropping strides, or even pretending to be 1D. This matches NumPy behaviour if you do something like: ``` a = np.full((3, 5), 30.0) struct.unpack_from('15d', a) # --> Produces 1D tuple from 2D buffer. ``` * Handle review comments * Test buffer protocol against NumPy * Also check PyBUF_FORMAT results
1 parent 75e48c5 commit bc041de

File tree

3 files changed

+382
-5
lines changed

3 files changed

+382
-5
lines changed

include/pybind11/detail/class.h

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -601,24 +601,70 @@ extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int fla
601601
set_error(PyExc_BufferError, "Writable buffer requested for readonly storage");
602602
return -1;
603603
}
604+
605+
// Fill in all the information, and then downgrade as requested by the caller, or raise an
606+
// error if that's not possible.
604607
view->obj = obj;
605-
view->ndim = 1;
606608
view->internal = info;
607609
view->buf = info->ptr;
608610
view->itemsize = info->itemsize;
609611
view->len = view->itemsize;
610612
for (auto s : info->shape) {
611613
view->len *= s;
612614
}
615+
view->ndim = static_cast<int>(info->ndim);
616+
view->shape = info->shape.data();
617+
view->strides = info->strides.data();
613618
view->readonly = static_cast<int>(info->readonly);
614619
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
615620
view->format = const_cast<char *>(info->format.c_str());
616621
}
617-
if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
618-
view->ndim = (int) info->ndim;
619-
view->strides = info->strides.data();
620-
view->shape = info->shape.data();
622+
623+
// Note, all contiguity flags imply PyBUF_STRIDES and lower.
624+
if ((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS) {
625+
if (PyBuffer_IsContiguous(view, 'C') == 0) {
626+
std::memset(view, 0, sizeof(Py_buffer));
627+
delete info;
628+
set_error(PyExc_BufferError,
629+
"C-contiguous buffer requested for discontiguous storage");
630+
return -1;
631+
}
632+
} else if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS) {
633+
if (PyBuffer_IsContiguous(view, 'F') == 0) {
634+
std::memset(view, 0, sizeof(Py_buffer));
635+
delete info;
636+
set_error(PyExc_BufferError,
637+
"Fortran-contiguous buffer requested for discontiguous storage");
638+
return -1;
639+
}
640+
} else if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS) {
641+
if (PyBuffer_IsContiguous(view, 'A') == 0) {
642+
std::memset(view, 0, sizeof(Py_buffer));
643+
delete info;
644+
set_error(PyExc_BufferError, "Contiguous buffer requested for discontiguous storage");
645+
return -1;
646+
}
647+
648+
} else if ((flags & PyBUF_STRIDES) != PyBUF_STRIDES) {
649+
// If no strides are requested, the buffer must be C-contiguous.
650+
// https://docs.python.org/3/c-api/buffer.html#contiguity-requests
651+
if (PyBuffer_IsContiguous(view, 'C') == 0) {
652+
std::memset(view, 0, sizeof(Py_buffer));
653+
delete info;
654+
set_error(PyExc_BufferError,
655+
"C-contiguous buffer requested for discontiguous storage");
656+
return -1;
657+
}
658+
659+
view->strides = nullptr;
660+
661+
// Since this is a contiguous buffer, it can also pretend to be 1D.
662+
if ((flags & PyBUF_ND) != PyBUF_ND) {
663+
view->shape = nullptr;
664+
view->ndim = 0;
665+
}
621666
}
667+
622668
Py_INCREF(view->obj);
623669
return 0;
624670
}

tests/test_buffers.cpp

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,125 @@ TEST_SUBMODULE(buffers, m) {
167167
sizeof(float)});
168168
});
169169

170+
// A matrix that uses Fortran storage order.
171+
class FortranMatrix : public Matrix {
172+
public:
173+
FortranMatrix(py::ssize_t rows, py::ssize_t cols) : Matrix(cols, rows) {
174+
print_created(this,
175+
std::to_string(rows) + "x" + std::to_string(cols) + " Fortran matrix");
176+
}
177+
178+
float operator()(py::ssize_t i, py::ssize_t j) const { return Matrix::operator()(j, i); }
179+
180+
float &operator()(py::ssize_t i, py::ssize_t j) { return Matrix::operator()(j, i); }
181+
182+
using Matrix::data;
183+
184+
py::ssize_t rows() const { return Matrix::cols(); }
185+
py::ssize_t cols() const { return Matrix::rows(); }
186+
};
187+
py::class_<FortranMatrix, Matrix>(m, "FortranMatrix", py::buffer_protocol())
188+
.def(py::init<py::ssize_t, py::ssize_t>())
189+
190+
.def("rows", &FortranMatrix::rows)
191+
.def("cols", &FortranMatrix::cols)
192+
193+
/// Bare bones interface
194+
.def("__getitem__",
195+
[](const FortranMatrix &m, std::pair<py::ssize_t, py::ssize_t> i) {
196+
if (i.first >= m.rows() || i.second >= m.cols()) {
197+
throw py::index_error();
198+
}
199+
return m(i.first, i.second);
200+
})
201+
.def("__setitem__",
202+
[](FortranMatrix &m, std::pair<py::ssize_t, py::ssize_t> i, float v) {
203+
if (i.first >= m.rows() || i.second >= m.cols()) {
204+
throw py::index_error();
205+
}
206+
m(i.first, i.second) = v;
207+
})
208+
/// Provide buffer access
209+
.def_buffer([](FortranMatrix &m) -> py::buffer_info {
210+
return py::buffer_info(m.data(), /* Pointer to buffer */
211+
{m.rows(), m.cols()}, /* Buffer dimensions */
212+
/* Strides (in bytes) for each index */
213+
{sizeof(float), sizeof(float) * size_t(m.rows())});
214+
});
215+
216+
// A matrix that uses a discontiguous underlying memory block.
217+
class DiscontiguousMatrix : public Matrix {
218+
public:
219+
DiscontiguousMatrix(py::ssize_t rows,
220+
py::ssize_t cols,
221+
py::ssize_t row_factor,
222+
py::ssize_t col_factor)
223+
: Matrix(rows * row_factor, cols * col_factor), m_row_factor(row_factor),
224+
m_col_factor(col_factor) {
225+
print_created(this,
226+
std::to_string(rows) + "(*" + std::to_string(row_factor) + ")x"
227+
+ std::to_string(cols) + "(*" + std::to_string(col_factor)
228+
+ ") matrix");
229+
}
230+
231+
~DiscontiguousMatrix() {
232+
print_destroyed(this,
233+
std::to_string(rows() / m_row_factor) + "(*"
234+
+ std::to_string(m_row_factor) + ")x"
235+
+ std::to_string(cols() / m_col_factor) + "(*"
236+
+ std::to_string(m_col_factor) + ") matrix");
237+
}
238+
239+
float operator()(py::ssize_t i, py::ssize_t j) const {
240+
return Matrix::operator()(i * m_row_factor, j * m_col_factor);
241+
}
242+
243+
float &operator()(py::ssize_t i, py::ssize_t j) {
244+
return Matrix::operator()(i * m_row_factor, j * m_col_factor);
245+
}
246+
247+
using Matrix::data;
248+
249+
py::ssize_t rows() const { return Matrix::rows() / m_row_factor; }
250+
py::ssize_t cols() const { return Matrix::cols() / m_col_factor; }
251+
py::ssize_t row_factor() const { return m_row_factor; }
252+
py::ssize_t col_factor() const { return m_col_factor; }
253+
254+
private:
255+
py::ssize_t m_row_factor;
256+
py::ssize_t m_col_factor;
257+
};
258+
py::class_<DiscontiguousMatrix, Matrix>(m, "DiscontiguousMatrix", py::buffer_protocol())
259+
.def(py::init<py::ssize_t, py::ssize_t, py::ssize_t, py::ssize_t>())
260+
261+
.def("rows", &DiscontiguousMatrix::rows)
262+
.def("cols", &DiscontiguousMatrix::cols)
263+
264+
/// Bare bones interface
265+
.def("__getitem__",
266+
[](const DiscontiguousMatrix &m, std::pair<py::ssize_t, py::ssize_t> i) {
267+
if (i.first >= m.rows() || i.second >= m.cols()) {
268+
throw py::index_error();
269+
}
270+
return m(i.first, i.second);
271+
})
272+
.def("__setitem__",
273+
[](DiscontiguousMatrix &m, std::pair<py::ssize_t, py::ssize_t> i, float v) {
274+
if (i.first >= m.rows() || i.second >= m.cols()) {
275+
throw py::index_error();
276+
}
277+
m(i.first, i.second) = v;
278+
})
279+
/// Provide buffer access
280+
.def_buffer([](DiscontiguousMatrix &m) -> py::buffer_info {
281+
return py::buffer_info(m.data(), /* Pointer to buffer */
282+
{m.rows(), m.cols()}, /* Buffer dimensions */
283+
/* Strides (in bytes) for each index */
284+
{size_t(m.col_factor()) * sizeof(float) * size_t(m.cols())
285+
* size_t(m.row_factor()),
286+
size_t(m.col_factor()) * sizeof(float)});
287+
});
288+
170289
class BrokenMatrix : public Matrix {
171290
public:
172291
BrokenMatrix(py::ssize_t rows, py::ssize_t cols) : Matrix(rows, cols) {}
@@ -268,4 +387,56 @@ TEST_SUBMODULE(buffers, m) {
268387
});
269388

270389
m.def("get_buffer_info", [](const py::buffer &buffer) { return buffer.request(); });
390+
391+
// Expose Py_buffer for testing.
392+
m.attr("PyBUF_FORMAT") = PyBUF_FORMAT;
393+
m.attr("PyBUF_SIMPLE") = PyBUF_SIMPLE;
394+
m.attr("PyBUF_ND") = PyBUF_ND;
395+
m.attr("PyBUF_STRIDES") = PyBUF_STRIDES;
396+
m.attr("PyBUF_INDIRECT") = PyBUF_INDIRECT;
397+
m.attr("PyBUF_C_CONTIGUOUS") = PyBUF_C_CONTIGUOUS;
398+
m.attr("PyBUF_F_CONTIGUOUS") = PyBUF_F_CONTIGUOUS;
399+
m.attr("PyBUF_ANY_CONTIGUOUS") = PyBUF_ANY_CONTIGUOUS;
400+
401+
m.def("get_py_buffer", [](const py::object &object, int flags) {
402+
Py_buffer buffer;
403+
memset(&buffer, 0, sizeof(Py_buffer));
404+
if (PyObject_GetBuffer(object.ptr(), &buffer, flags) == -1) {
405+
throw py::error_already_set();
406+
}
407+
408+
auto SimpleNamespace = py::module_::import("types").attr("SimpleNamespace");
409+
py::object result = SimpleNamespace("len"_a = buffer.len,
410+
"readonly"_a = buffer.readonly,
411+
"itemsize"_a = buffer.itemsize,
412+
"format"_a = buffer.format,
413+
"ndim"_a = buffer.ndim,
414+
"shape"_a = py::none(),
415+
"strides"_a = py::none(),
416+
"suboffsets"_a = py::none());
417+
if (buffer.shape != nullptr) {
418+
py::list l;
419+
for (auto i = 0; i < buffer.ndim; i++) {
420+
l.append(buffer.shape[i]);
421+
}
422+
py::setattr(result, "shape", l);
423+
}
424+
if (buffer.strides != nullptr) {
425+
py::list l;
426+
for (auto i = 0; i < buffer.ndim; i++) {
427+
l.append(buffer.strides[i]);
428+
}
429+
py::setattr(result, "strides", l);
430+
}
431+
if (buffer.suboffsets != nullptr) {
432+
py::list l;
433+
for (auto i = 0; i < buffer.ndim; i++) {
434+
l.append(buffer.suboffsets[i]);
435+
}
436+
py::setattr(result, "suboffsets", l);
437+
}
438+
439+
PyBuffer_Release(&buffer);
440+
return result;
441+
});
271442
}

0 commit comments

Comments
 (0)