@@ -225,51 +225,76 @@ namespace pyAMReX
         */


-        // DLPack protocol (CPU, NVIDIA GPU, AMD GPU, Intel GPU, etc.)
+        // DLPack v1.1 protocol (CPU, NVIDIA GPU, AMD GPU, Intel GPU, etc.)
         // https://dmlc.github.io/dlpack/latest/
         // https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
         // https://docs.cupy.dev/en/stable/user_guide/interoperability.html#dlpack-data-exchange-protocol
-        .def("__dlpack__", [](Array4<T> const &a4, [[maybe_unused]] py::handle stream = py::none()) {
+        .def("__dlpack__", [](
+            Array4<T> const &a4
+            /* TODO:
+            [[maybe_unused]] py::handle stream,
+            [[maybe_unused]] std::tuple<int, int> max_version,
+            [[maybe_unused]] std::tuple<DLDeviceType, int32_t> dl_device,
+            [[maybe_unused]] bool copy
+            */
+        )
+        {
             // Allocate shape/strides arrays
             constexpr int ndim = 4;
             auto const len = length(a4);
-            auto *shape = new int64_t[ndim]{a4.nComp(), len.z, len.y, len.x};
-            auto *strides = new int64_t[ndim]{a4.nstride, a4.kstride, a4.jstride, 1};
-
-            // Construct DLTensor
-            auto *dl_tensor = new DLManagedTensor;
-            dl_tensor->dl_tensor.data = const_cast<void *>(static_cast<const void *>(a4.dataPtr()));
-            dl_tensor->dl_tensor.device = dlpack::detect_device_from_pointer(a4.dataPtr());
-            dl_tensor->dl_tensor.ndim = ndim;
-            dl_tensor->dl_tensor.dtype = dlpack::get_dlpack_dtype<T>();
-            dl_tensor->dl_tensor.shape = shape;
-            dl_tensor->dl_tensor.strides = strides;
-            dl_tensor->dl_tensor.byte_offset = 0;
-            dl_tensor->manager_ctx = nullptr;
-            dl_tensor->deleter = [](DLManagedTensor *self) {
+
+            // Construct DLManagedTensorVersioned (DLPack 1.1 standard)
+            auto *dl_mgt_tensor = new DLManagedTensorVersioned;
+            // dl_mgt_tensor->version = DLPackVersion{};
+            dl_mgt_tensor->version.major = 1;
+            dl_mgt_tensor->version.minor = 1;
+            dl_mgt_tensor->flags = 0;  // No special flags
+            dl_mgt_tensor->dl_tensor.data = const_cast<void *>(static_cast<const void *>(a4.dataPtr()));
+            dl_mgt_tensor->dl_tensor.device = dlpack::detect_device_from_pointer(a4.dataPtr());
+            dl_mgt_tensor->dl_tensor.ndim = ndim;
+            dl_mgt_tensor->dl_tensor.dtype = dlpack::get_dlpack_dtype<T>();
+            dl_mgt_tensor->dl_tensor.shape = new int64_t[ndim]{a4.nComp(), len.z, len.y, len.x};
+            dl_mgt_tensor->dl_tensor.strides = new int64_t[ndim]{a4.nstride, a4.kstride, a4.jstride, 1};
+            dl_mgt_tensor->dl_tensor.byte_offset = 0;
+            dl_mgt_tensor->manager_ctx = nullptr;  // TODO: we can increase/decrease the Python ref counter of the producer here
+            dl_mgt_tensor->deleter = [](DLManagedTensorVersioned *self) {
                 delete[] self->dl_tensor.shape;
                 delete[] self->dl_tensor.strides;
                 delete self;
             };
             // Return as Python capsule
-            return py::capsule(dl_tensor, "dltensor", [](void* ptr) {
-                auto* tensor = static_cast<DLManagedTensor*>(ptr);
-                tensor->deleter(tensor);
-            });
+            return py::capsule(
+                dl_mgt_tensor,
+                "dltensor",
+                /* [](void* ptr) {
+                    auto* tensor = static_cast<DLManagedTensorVersioned*>(ptr);
+                    tensor->deleter(tensor);
+                }*/
+                [](PyObject *capsule)
+                {
+                    auto *p = static_cast<DLManagedTensorVersioned*>(
+                        PyCapsule_GetPointer(capsule, "dltensor"));
+                    if (p && p->deleter)
+                        p->deleter(p);
+                }
+            );
         },
-        py::arg("stream") = py::none(),
+        // py::arg("stream") = py::none(),
+        // ... other args & their defaults
         R"doc(
             DLPack protocol for zero-copy tensor exchange.
             See https://dmlc.github.io/dlpack/latest/ for details.
         )doc"
         )
         .def("__dlpack_device__", [](Array4<T> const &a4) {
             DLDevice device = dlpack::detect_device_from_pointer(a4.dataPtr());
-            return std::make_tuple(device.device_type, device.device_id);
+            return std::make_tuple(static_cast<int32_t>(device.device_type), device.device_id);
         }, R"doc(
             DLPack device info (device_type, device_id).
         )doc")

+
+
         .def("to_host", [](Array4<T> const & a4) {
             // py::tuple to std::vector
             auto const a4i = pyAMReX::array_interface(a4);
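
Below is a minimal consumer-side sketch of how the new __dlpack__ / __dlpack_device__ bindings are meant to be exercised from Python once the remaining TODO arguments are wired up. The setup code (amrex.space3d, MultiFab, mf.array(mfi)) follows the usual pyAMReX examples and is an assumption here, not part of this diff; np.from_dlpack() is the standard NumPy entry point that calls __dlpack_device__() and __dlpack__() on the producer.

    # Hypothetical usage sketch (names assumed, not defined in this diff)
    import numpy as np
    import amrex.space3d as amr

    amr.initialize([])

    box = amr.Box(amr.IntVect(0, 0, 0), amr.IntVect(31, 31, 31))
    ba = amr.BoxArray(box)
    dm = amr.DistributionMapping(ba)
    mf = amr.MultiFab(ba, dm, 1, 0)   # 1 component, no ghost cells
    mf.set_val(1.0)

    for mfi in mf:
        a4 = mf.array(mfi)            # Array4<double> binding (producer)

        # DLPack device info arrives as plain (device_type, device_id) integers
        dev_type, dev_id = a4.__dlpack_device__()

        # Zero-copy import: NumPy consumes the "dltensor" capsule produced above
        view = np.from_dlpack(a4)     # expected shape: (ncomp, nz, ny, nx)

    amr.finalize()

The static_cast<int32_t> in __dlpack_device__ is what makes dev_type come back as a plain Python int here rather than requiring an enum binding for DLDeviceType.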