@@ -319,22 +319,57 @@ template<int dimensions>
319319struct PyHalideBuffer {
320320 // Must allocate at least 1, even if d=0
321321 static constexpr int dims_to_allocate = (dimensions < 1) ? 1 : dimensions;
322+ static constexpr const char* get_raw_halide_runtime_buffer_fn = "_get_raw_halide_buffer_t";
322323
323324 Py_buffer py_buf;
324- halide_dimension_t halide_dim[dims_to_allocate];
325- halide_buffer_t halide_buf;
325+ halide_buffer_t* halide_buf = nullptr;
326326 bool py_buf_needs_release = false;
327327 bool needs_device_free = false;
328328
329+ bool unpack_from_halide_buffer(PyObject *py_obj) {
330+ if (!PyObject_HasAttrString(py_obj, get_raw_halide_runtime_buffer_fn)) {
331+ return false;
332+ }
333+
334+ PyObject *py_raw_buffer = PyObject_CallMethod(py_obj, get_raw_halide_runtime_buffer_fn, NULL);
335+ if (!py_raw_buffer) {
336+ PyErr_Clear();
337+ return false;
338+ }
339+
340+ if (!PyLong_Check(py_raw_buffer)) {
341+ Py_DECREF(py_raw_buffer);
342+ return false;
343+ }
344+
345+ uintptr_t py_raw_buffer_ptr = (uintptr_t)PyLong_AsUnsignedLongLong(py_raw_buffer);
346+ Py_DECREF(py_raw_buffer);
347+
348+ if (py_raw_buffer_ptr == 0) {
349+ return false;
350+ }
351+
352+ halide_buf = reinterpret_cast<halide_buffer_t *>(py_raw_buffer_ptr);
353+ return true;
354+ }
355+
// Fill in halide_buf from py_obj; returns true on success, false otherwise.
// The added (+) lines first try unpack_from_halide_buffer(py_obj) and, only if
// that fails, fall back to Halide::PythonRuntime::unpack_buffer, which fills
// the struct-owned unpacked_dim / unpacked_buf. The removed (-) lines show the
// previous form, which called unpack_buffer directly on halide_dim / halide_buf.
329356 bool unpack(PyObject *py_obj, int py_getbuffer_flags, const char *name) {
330- return Halide::PythonRuntime::unpack_buffer(py_obj, py_getbuffer_flags,
331- name, dimensions, py_buf, halide_dim, halide_buf, py_buf_needs_release,
332- needs_device_free);
357+ if (unpack_from_halide_buffer(py_obj)) {
358+ return true;
359+ }
360+ if (Halide::PythonRuntime::unpack_buffer(
361+ py_obj, py_getbuffer_flags, name, dimensions, py_buf,
362+ unpacked_dim, unpacked_buf, py_buf_needs_release,
363+ needs_device_free)) {
// On the fallback path halide_buf points at this struct's own unpacked_buf.
364+ halide_buf = &unpacked_buf;
365+ return true;
366+ }
367+ return false;
333368 }
334369
335370 ~PyHalideBuffer() {
336371 if (needs_device_free) {
337- halide_device_free(nullptr, & halide_buf);
372+ halide_device_free(nullptr, halide_buf);
338373 }
339374 if (py_buf_needs_release) {
340375 PyBuffer_Release(&py_buf);
@@ -346,6 +381,10 @@ struct PyHalideBuffer {
346381 PyHalideBuffer &operator=(const PyHalideBuffer &other) = delete;
347382 PyHalideBuffer(PyHalideBuffer &&other) = delete;
348383 PyHalideBuffer &operator=(PyHalideBuffer &&other) = delete;
384+
385+ private:
386+ halide_dimension_t unpacked_dim[dims_to_allocate];
387+ halide_buffer_t unpacked_buf;
349388};
350389
351390} // namespace
@@ -470,7 +509,7 @@ void PythonExtensionGen::compile(const LoweredFunc &f) {
470509 // do a lazy-copy-to-GPU if needed.
471510 for (size_t i = 0 ; i < args.size (); i++) {
472511 if (args[i].is_buffer () && args[i].is_input ()) {
473- dest << indent << " b_" << arg_names[i] << " .halide_buf. set_host_dirty();\n " ;
512+ dest << indent << " b_" << arg_names[i] << " .halide_buf-> set_host_dirty();\n " ;
474513 }
475514 }
476515 dest << indent << " int result;\n " ;
@@ -479,7 +518,7 @@ void PythonExtensionGen::compile(const LoweredFunc &f) {
479518 indent.indent += 2 ;
480519 for (size_t i = 0 ; i < args.size (); i++) {
481520 if (args[i].is_buffer ()) {
482- dest << indent << " & b_" << arg_names[i] << " .halide_buf" ;
521+ dest << indent << " b_" << arg_names[i] << " .halide_buf" ;
483522 } else {
484523 dest << indent << " py_" << arg_names[i] << " " ;
485524 }
@@ -496,7 +535,7 @@ void PythonExtensionGen::compile(const LoweredFunc &f) {
496535 // random garbage. (We need a better solution for this, see https://github.com/halide/Halide/issues/6868)
497536 for (size_t i = 0 ; i < args.size (); i++) {
498537 if (args[i].is_buffer () && args[i].is_output ()) {
499- dest << indent << " if (result == 0) result = halide_copy_to_host(nullptr, & b_" << arg_names[i] << " .halide_buf);\n " ;
538+ dest << indent << " if (result == 0) result = halide_copy_to_host(nullptr, b_" << arg_names[i] << " .halide_buf);\n " ;
500539 }
501540 }
502541 dest << indent << " if (result != 0) {\n " ;
0 commit comments