diff --git a/paddle/fluid/framework/dlpack_tensor.cc b/paddle/fluid/framework/dlpack_tensor.cc
index 793d0bbdf6e695..e01964966d2727 100644
--- a/paddle/fluid/framework/dlpack_tensor.cc
+++ b/paddle/fluid/framework/dlpack_tensor.cc
@@ -265,7 +265,7 @@ ::DLDataType PhiDataTypeToDLDataType(phi::DataType dtype) {
       framework::TransToProtoVarType(dtype));
 }
 
-phi::Place DLDeviceToPlace(const DLDevice &dl_device) {
+phi::Place DLDeviceToPlace(const ::DLDevice &dl_device) {
   phi::Place place;
   if (dl_device.device_type == kDLCPU) {
     place = phi::CPUPlace();
@@ -279,7 +279,7 @@ phi::Place DLDeviceToPlace(const DLDevice &dl_device) {
   return place;
 }
 
-DLDevice PlaceToDLDevice(const phi::Place &place) {
+::DLDevice PlaceToDLDevice(const phi::Place &place) {
   return phi::VisitPlace(place, internal::DLDeviceVisitor());
 }
diff --git a/paddle/fluid/framework/dlpack_tensor.h b/paddle/fluid/framework/dlpack_tensor.h
index e287ce342fa78c..ed799a192f83f9 100644
--- a/paddle/fluid/framework/dlpack_tensor.h
+++ b/paddle/fluid/framework/dlpack_tensor.h
@@ -29,15 +29,17 @@
    and paddle/phi/api/lib/tensor_utils.cc */
 using Deleter = std::function<void(void*)>;
 
-phi::Place DLDeviceToPlace(const DLDevice& device);
-DLDevice PlaceToDLDevice(const phi::Place& place);
+::DLDataType PhiDataTypeToDLDataType(phi::DataType dtype);
+phi::DataType DLDataTypeToPhiDataType(::DLDataType type);
+phi::Place DLDeviceToPlace(const ::DLDevice& device);
+::DLDevice PlaceToDLDevice(const phi::Place& place);
 
 TEST_API DLManagedTensor* ToDLPack(const phi::DenseTensor& src,
                                    uint64_t flags = 0);
-DLManagedTensorVersioned* ToDLPackVersioned(const phi::DenseTensor& src,
-                                            uint64_t flags = 0);
-TEST_API phi::DenseTensor FromDLPack(DLManagedTensor* src);
-phi::DenseTensor FromDLPackVersioned(DLManagedTensorVersioned* src);
+::DLManagedTensorVersioned* ToDLPackVersioned(const phi::DenseTensor& src,
+                                              uint64_t flags = 0);
+TEST_API phi::DenseTensor FromDLPack(::DLManagedTensor* src);
+phi::DenseTensor FromDLPackVersioned(::DLManagedTensorVersioned* src);
 
 // A traits to support both DLManagedTensor and DLManagedTensorVersioned
 template <typename T>
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index d3b17ad377b7cf..d2c7b52f272af4 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -763,6 +763,64 @@ class PyLayerBlockContextManager {
   PyLayerBlockContextManager() = default;
 };
 
+int DLPackFromPyObject(void *py_obj,
+                       DLManagedTensorVersioned **out,
+                       void **env_stream) {
+  try {
+    py::handle handle(static_cast<PyObject *>(py_obj));
+    paddle::Tensor tensor = handle.cast<paddle::Tensor>();
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
+    defined(PADDLE_WITH_CUSTOM_DEVICE)
+    if (env_stream != nullptr && tensor.is_gpu()) {
+      int device_index = tensor.place().GetDeviceId();
+      *env_stream = platform::get_current_stream(device_index)->raw_stream();
+    }
+#endif
+    std::shared_ptr<phi::DenseTensor> dense_tensor =
+        std::static_pointer_cast<phi::DenseTensor>(tensor.impl());
+    *out = paddle::framework::ToDLPackVersioned(*dense_tensor);
+    return 0;
+  } catch (const std::exception &e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+    return -1;
+  }
+}
+
+int DLPackToPyObject(DLManagedTensorVersioned *src, void **py_obj_out) {
+  try {
+    phi::DenseTensor dense_tensor = paddle::framework::FromDLPackVersioned(src);
+    paddle::Tensor tensor(std::make_shared<phi::DenseTensor>(dense_tensor));
+    egr::EagerUtils::autograd_meta(&tensor)->SetPersistable(false);
+    *py_obj_out = ToPyObject(tensor);
+    return 0;
+  } catch (const std::exception &e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+    return -1;
+  }
+}
+
+int DLPackTensorAllocator(::DLTensor *prototype,
+                          ::DLManagedTensorVersioned **out,
+                          void *error_ctx,
+                          void (*SetError)(void *error_ctx,
+                                           const char *kind,
+                                           const char *message)) {
+  try {
+    phi::IntArray shape(prototype->shape, prototype->ndim);
+    phi::Place place(paddle::framework::DLDeviceToPlace(prototype->device));
+    phi::DataType dtype =
+        paddle::framework::DLDataTypeToPhiDataType(prototype->dtype);
+    paddle::Tensor tensor = paddle::empty(shape, dtype, place);
+    std::shared_ptr<phi::DenseTensor> dense_tensor =
+        std::static_pointer_cast<phi::DenseTensor>(tensor.impl());
+    *out = paddle::framework::ToDLPackVersioned(*dense_tensor);
+    return 0;
+  } catch (const std::exception &e) {
+    SetError(error_ctx, "DLPackTensorAllocator", e.what());
+    return -1;
+  }
+}
+
 // NOTE: use to load file by Mmap
 enum MMapLoadModes {
   ALLOCATOR_MAPPED_SHARED = 1,
@@ -1773,6 +1831,18 @@ PYBIND11_MODULE(libpaddle, m) {
                           dl_device.device_id);
       });
 
+  m.def("dlpack_from_pyobject_ptr", []() -> int64_t {
+    return reinterpret_cast<int64_t>(DLPackFromPyObject);
+  });
+
+  m.def("dlpack_to_pyobject_ptr", []() -> int64_t {
+    return reinterpret_cast<int64_t>(DLPackToPyObject);
+  });
+
+  m.def("dlpack_tensor_allocator_ptr", []() -> int64_t {
+    return reinterpret_cast<int64_t>(DLPackTensorAllocator);
+  });
+
   m.def("from_dlpack", [](py::object data) {
     if (PyCapsule_IsValid(data.ptr(),
                           DLPackTraits<DLManagedTensor>::capsule)) {
diff --git a/python/paddle/base/dygraph/tensor_patch_methods.py b/python/paddle/base/dygraph/tensor_patch_methods.py
index e19d5e7f8405d1..2650ebd77f5a29 100644
--- a/python/paddle/base/dygraph/tensor_patch_methods.py
+++ b/python/paddle/base/dygraph/tensor_patch_methods.py
@@ -1586,6 +1586,9 @@ def __tvm_ffi_env_stream__(self) -> int:
         ("__dlpack_device__", __dlpack_device__),
         ("get_device", get_device),
         ("__tvm_ffi_env_stream__", __tvm_ffi_env_stream__),
+        ("__c_dlpack_from_pyobject__", core.dlpack_from_pyobject_ptr()),
+        ("__c_dlpack_to_pyobject__", core.dlpack_to_pyobject_ptr()),
+        ("__c_dlpack_tensor_allocator__", core.dlpack_tensor_allocator_ptr()),
     ):
         setattr(core.eager.Tensor, method_name, method)
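
Not part of the patch — a minimal smoke-test sketch of how the exported pointers surface on the Python side, assuming a build that includes this change. Each new `__c_dlpack_*__` attribute is set once as a plain integer holding the address of the corresponding C function from `pybind.cc`, so a C-level consumer (e.g. a TVM FFI-style DLPack exchange) can call it directly as `int (*)(void *, DLManagedTensorVersioned **, void **)` etc. instead of round-tripping through Python-level `__dlpack__` calls; `paddle.ones` below is only a stand-in tensor:

```python
import paddle
from paddle.base import core

x = paddle.ones([2, 3])

# The attributes are class-level integers set by tensor_patch_methods.py;
# they are raw function addresses, not Python callables.
for name in (
    "__c_dlpack_from_pyobject__",
    "__c_dlpack_to_pyobject__",
    "__c_dlpack_tensor_allocator__",
):
    print(name, hex(getattr(x, name)))

# The same addresses are available directly from the extension module.
assert x.__c_dlpack_from_pyobject__ == core.dlpack_from_pyobject_ptr()
assert x.__c_dlpack_to_pyobject__ == core.dlpack_to_pyobject_ptr()
assert x.__c_dlpack_tensor_allocator__ == core.dlpack_tensor_allocator_ptr()
```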