
Commit 59fec5d

[cherry-pick 2.4] Fix to_dlpack (#50138) (#50250)
* Fix to_dlpack (#50138)
* fix to_dlpack for loop
* fix reference count
* fix conflicts
1 parent b50f04a · commit 59fec5d

4 files changed: +92 -31 lines

paddle/fluid/framework/dlpack_tensor.cc

Lines changed: 53 additions & 1 deletion
@@ -134,7 +134,59 @@ struct DLDeviceVisitor
 };
 }  // namespace internal
 
-DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) {
+struct PaddleDLMTensor {
+  phi::DenseTensor handle;
+  DLManagedTensor tensor;
+};
+
+void deleter(DLManagedTensor *arg) {
+  delete[] arg->dl_tensor.shape;
+  delete[] arg->dl_tensor.strides;
+  delete static_cast<PaddleDLMTensor *>(arg->manager_ctx);
+}
+
+DLManagedTensor *toDLPack(const phi::DenseTensor &src) {
+  PaddleDLMTensor *pdDLMTensor(new PaddleDLMTensor);
+  pdDLMTensor->handle = const_cast<phi::DenseTensor &>(src);
+  pdDLMTensor->tensor.manager_ctx = pdDLMTensor;
+  pdDLMTensor->tensor.deleter = &deleter;
+  pdDLMTensor->tensor.dl_tensor.data = const_cast<void *>(src.data());
+
+  // init ndim
+  using DimType = decltype(pdDLMTensor->tensor.dl_tensor.ndim);  // int
+  pdDLMTensor->tensor.dl_tensor.ndim = static_cast<DimType>(src.dims().size());
+  DimType ndim = pdDLMTensor->tensor.dl_tensor.ndim;
+
+  // init shape
+  auto shape = new int64_t[ndim];
+  for (DimType i = 0; i < ndim; ++i) {
+    shape[i] = src.dims()[i];
+  }
+  pdDLMTensor->tensor.dl_tensor.shape = shape;
+
+  // init stride
+  auto strides = new int64_t[ndim];
+  for (DimType i = 0; i < ndim; ++i) {
+    strides[i] = 1;
+  }
+  for (DimType i = ndim - 2; i >= 0; --i) {
+    strides[i] = shape[i + 1] * strides[i + 1];
+  }
+  pdDLMTensor->tensor.dl_tensor.strides = strides;
+
+  // init device, DLDevice type with device_type and device_id
+  auto place = src.place();
+  pdDLMTensor->tensor.dl_tensor.device =
+      paddle::platform::VisitPlace(place, internal::DLDeviceVisitor());
+
+  pdDLMTensor->tensor.dl_tensor.dtype = internal::GetDLDataTypeFromTypeIndex(
+      framework::TransToProtoVarType(src.dtype()));
+
+  pdDLMTensor->tensor.dl_tensor.byte_offset = 0;
+  return &(pdDLMTensor->tensor);
+}
+
+DLPackTensor::DLPackTensor(const phi::DenseTensor &tensor, LaneType lanes) {
   // init data, data buffer
   t_.data = const_cast<void *>(tensor.data());
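
Two details of the new toDLPack are worth spelling out. First, PaddleDLMTensor stores a phi::DenseTensor handle by value, so the underlying allocation stays referenced until the consumer invokes the deleter; this appears to be the "fix reference count" item in the commit message. Second, the stride loops fill in compact row-major strides: the innermost axis gets stride 1, and each earlier axis strides over the product of the axes after it. A minimal Python sketch of that stride computation (illustrative only, not part of the commit):

    def contiguous_strides(shape):
        # Mirrors the stride loops in toDLPack: the last axis has stride 1,
        # and each earlier axis strides over everything after it.
        ndim = len(shape)
        strides = [1] * ndim
        for i in range(ndim - 2, -1, -1):
            strides[i] = shape[i + 1] * strides[i + 1]
        return strides

    # For a [3, 4, 5] tensor the element strides are [20, 5, 1].
    assert contiguous_strides([3, 4, 5]) == [20, 5, 1]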

paddle/fluid/framework/dlpack_tensor.h

Lines changed: 3 additions & 1 deletion
@@ -28,7 +28,7 @@ class DLPackTensor {
       std::remove_reference<decltype(::DLTensor::shape[0])>::type;  // int64_t
 
   // lanes is only used in CPU to enable vectorization
-  explicit DLPackTensor(const Tensor& tensor, LaneType lanes = 1);
+  explicit DLPackTensor(const phi::DenseTensor& tensor, LaneType lanes = 1);
 
   inline operator const ::DLTensor&() const { return t_; }
 
@@ -44,5 +44,7 @@ class DLPackTensor {
   ShapeType shape_[DDim::kMaxRank];
 };
 
+DLManagedTensor* toDLPack(const phi::DenseTensor& src);
+
 }  // namespace framework
 }  // namespace paddle

paddle/fluid/pybind/tensor.cc

Lines changed: 8 additions & 15 deletions
@@ -472,23 +472,16 @@ void BindTensor(pybind11::module &m) {  // NOLINT
            print(t.shape())  # [5, 30]
      )DOC")
       .def("_to_dlpack",
-           [](framework::Tensor &self) {
-             DLPackTensor dlpack_tensor(self, 1);
-             DLManagedTensor *dmt = dlpack_tensor.ToDLManagedTensor();
-             auto capsule = py::capsule(
+           [](phi::DenseTensor &self) {
+             DLManagedTensor *dmt = framework::toDLPack(self);
+             auto capsule = pybind11::capsule(
                  static_cast<void *>(dmt), "dltensor", [](PyObject *ptr) {
-                  if (ptr) {
-                    auto dltensor = new DLManagedTensor;
-                    try {
-                      dltensor = reinterpret_cast<DLManagedTensor *>(
-                          PyCapsule_GetPointer(ptr, "used_dltensor"));
-                      return;
-                    } catch (...) {
-                      dltensor = reinterpret_cast<DLManagedTensor *>(
-                          PyCapsule_GetPointer(ptr, "dltensor"));
-                    }
-                    dltensor->deleter(dltensor);
+                  if (!PyCapsule_IsValid(ptr, "dltensor")) {
+                    return;
                   }
+                  DLManagedTensor *dmt = static_cast<DLManagedTensor *>(
+                      PyCapsule_GetPointer(ptr, "dltensor"));
+                  dmt->deleter(dmt);
                 });
             return capsule;
           })
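
The rewritten capsule destructor encodes the standard DLPack handshake: the producer names the capsule "dltensor", and a consumer that takes ownership renames it to "used_dltensor", so PyCapsule_IsValid(ptr, "dltensor") fails and the destructor does nothing. Only an unconsumed capsule is freed here, exactly once. A small sketch of that contract from the Python side (uses the public paddle.utils.dlpack API; illustrative, not part of the commit):

    import paddle

    x = paddle.to_tensor([1.0, 2.0, 3.0])

    # Producer side: wraps x in a PyCapsule named "dltensor".
    capsule = paddle.utils.dlpack.to_dlpack(x)

    # Consumer side: from_dlpack takes ownership and, per the DLPack
    # convention, renames the capsule to "used_dltensor", so the capsule
    # destructor above becomes a no-op for it.
    y = paddle.utils.dlpack.from_dlpack(capsule)

    # A capsule that is never consumed keeps the name "dltensor"; when it
    # is garbage-collected, the destructor runs dmt->deleter(dmt), which
    # releases the shape/strides arrays and the PaddleDLMTensor holder.
    unused = paddle.utils.dlpack.to_dlpack(paddle.to_tensor([4.0]))
    del unused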

python/paddle/tests/test_dlpack.py

Lines changed: 28 additions & 14 deletions
@@ -22,19 +22,20 @@
 
 
 class TestDLPack(unittest.TestCase):
-
     def func_test_dlpack_dygraph(self):
         paddle.disable_static()
         tensor = paddle.to_tensor(np.array([1, 2, 3, 4]).astype('int'))
         dlpack = paddle.utils.dlpack.to_dlpack(tensor)
         out_from_dlpack = paddle.utils.dlpack.from_dlpack(dlpack)
         if paddle.fluid.framework.in_dygraph_mode():
             self.assertTrue(
-                isinstance(out_from_dlpack, paddle.fluid.core.eager.Tensor))
+                isinstance(out_from_dlpack, paddle.fluid.core.eager.Tensor)
+            )
         else:
             self.assertTrue(isinstance(out_from_dlpack, paddle.Tensor))
-        np.testing.assert_array_equal(np.array(out_from_dlpack),
-                                      np.array([1, 2, 3, 4]).astype('int'))
+        np.testing.assert_array_equal(
+            np.array(out_from_dlpack), np.array([1, 2, 3, 4]).astype('int')
+        )
 
     def test_dlpack_dygraph(self):
         with _test_eager_guard():
@@ -58,26 +59,32 @@ def test_dlpack_tensor_larger_than_2dim(self):
     def test_dlpack_static(self):
         paddle.enable_static()
         tensor = fluid.create_lod_tensor(
-            np.array([[1], [2], [3], [4]]).astype('int'), [[1, 3]],
-            fluid.CPUPlace())
+            np.array([[1], [2], [3], [4]]).astype('int'),
+            [[1, 3]],
+            fluid.CPUPlace(),
+        )
         dlpack = paddle.utils.dlpack.to_dlpack(tensor)
         out_from_dlpack = paddle.utils.dlpack.from_dlpack(dlpack)
         self.assertTrue(isinstance(out_from_dlpack, fluid.core.Tensor))
         np.testing.assert_array_equal(
             np.array(out_from_dlpack),
-            np.array([[1], [2], [3], [4]]).astype('int'))
+            np.array([[1], [2], [3], [4]]).astype('int'),
+        )
 
         # when build with cuda
         if core.is_compiled_with_cuda():
             gtensor = fluid.create_lod_tensor(
-                np.array([[1], [2], [3], [4]]).astype('int'), [[1, 3]],
-                fluid.CUDAPlace(0))
+                np.array([[1], [2], [3], [4]]).astype('int'),
+                [[1, 3]],
+                fluid.CUDAPlace(0),
+            )
             gdlpack = paddle.utils.dlpack.to_dlpack(gtensor)
             gout_from_dlpack = paddle.utils.dlpack.from_dlpack(gdlpack)
             self.assertTrue(isinstance(gout_from_dlpack, fluid.core.Tensor))
             np.testing.assert_array_equal(
                 np.array(gout_from_dlpack),
-                np.array([[1], [2], [3], [4]]).astype('int'))
+                np.array([[1], [2], [3], [4]]).astype('int'),
+            )
 
     def func_test_dlpack_dtype_conversion(self):
         paddle.disable_static()
@@ -104,7 +111,8 @@ def func_test_dlpack_dtype_conversion(self):
         for dtype in complex_dtypes:
             x = paddle.to_tensor(
                 [[1 + 6j, 2 + 5j, 3 + 4j], [4 + 3j, 5 + 2j, 6 + 1j]],
-                dtype=dtype)
+                dtype=dtype,
+            )
             dlpack = paddle.utils.dlpack.to_dlpack(x)
             o = paddle.utils.dlpack.from_dlpack(dlpack)
             self.assertEqual(x.dtype, o.dtype)
@@ -115,12 +123,18 @@ def test_dlpack_dtype_conversion(self):
         self.func_test_dlpack_dtype_conversion()
         self.func_test_dlpack_dtype_conversion()
 
+    def test_to_dlpack_for_loop(self):
+        # See Paddle issue 50120
+        for i in range(10):
+            x = paddle.rand([3, 5])
+            dlpack = paddle.utils.dlpack.to_dlpack(x)
 
-class TestRaiseError(unittest.TestCase):
 
+class TestRaiseError(unittest.TestCase):
     def func_test_from_dlpack_raise_type_error(self):
-        self.assertRaises(TypeError, paddle.utils.dlpack.from_dlpack,
-                          np.zeros(5))
+        self.assertRaises(
+            TypeError, paddle.utils.dlpack.from_dlpack, np.zeros(5)
+        )
 
     def test_from_dlpack_raise_type_error(self):
         with _test_eager_guard():
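
The new test_to_dlpack_for_loop case is a pure smoke test for Paddle issue 50120: the old export path crashed when to_dlpack was called repeatedly in a loop, while the new toDLPack heap-allocates the holder so each exported capsule owns its own metadata. A slightly stronger manual check, assuming from_dlpack round-trips correctly (illustrative, not part of the commit):

    import numpy as np
    import paddle

    for _ in range(10):
        x = paddle.rand([3, 5])
        capsule = paddle.utils.dlpack.to_dlpack(x)
        y = paddle.utils.dlpack.from_dlpack(capsule)
        # The exported buffer aliases x's memory, so the round trip
        # must reproduce the values exactly.
        np.testing.assert_array_equal(x.numpy(), y.numpy())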
