
Commit 97fd314

[XPU] support python streams api for xpu (#73924)
* [XPU] support python streams api for xpu
* [XPU] support python streams api for xpu
* [XPU] add stream & event unittests
1 parent 97256cd commit 97fd314
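
This commit wires a CUDA-style stream/event workflow into the XPU backend and exposes it to Python. A minimal usage sketch follows; it assumes a PADDLE_WITH_XPU build with at least one visible device, and that `paddle.device.xpu.Stream` and `paddle.device.xpu.Event` wrap the `XPUStream`/`XPUEvent` bindings added below (the Python-side wrapper names come from this commit's docstrings, not from this diff itself):

    # Sketch only: assumes paddle.device.xpu.Stream/Event wrap the new
    # XPUStream/XPUEvent pybind classes and an XPU device is available.
    import paddle

    s1 = paddle.device.xpu.Stream(paddle.XPUPlace(0))  # stream on an explicit place
    s2 = paddle.device.xpu.Stream()                    # stream on the current device

    e = paddle.device.xpu.Event()
    e.record(s1)        # mark s1's current position with an event
    s2.wait_event(e)    # s2 will not run past this point until e completes
    s2.wait_stream(s1)  # alternatively, order s2 behind all work queued on s1

    s1.synchronize()    # block the host until s1's queued work finishes
    e.synchronize()     # block the host until the event completes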

18 files changed: +848 -37 lines

paddle/fluid/pybind/tensor.cc

Lines changed: 1 addition & 1 deletion
@@ -903,7 +903,7 @@ void BindTensor(pybind11::module &m) { // NOLINT
             const auto &device_id =
                 paddle::platform::GetXPUCurrentDeviceId();
             auto stream = paddle::platform::get_current_stream(device_id);
-            xpu_wait(stream);
+            xpu_wait(stream->raw_stream());
             int type_idx = static_cast<int>(self.type());
             size_t data_size = self.numel() *
                                framework::SizeOfType(

paddle/fluid/pybind/xpu_streams_py.cc

Lines changed: 209 additions & 16 deletions
@@ -33,19 +33,27 @@ namespace py = pybind11;
 namespace paddle {
 namespace platform {
 #ifdef PADDLE_WITH_XPU
-XPUStream get_current_stream(int device_id) {
-  if (device_id == -1) {
-    device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
-  }
+phi::XPUStreamHandle *get_current_stream(int device_id) {
   auto place = phi::XPUPlace(device_id);
   auto *dev_ctx = static_cast<phi::XPUContext *>(
       phi::DeviceContextPool::Instance().Get(place));
   dev_ctx->Wait();
-  return dev_ctx->stream();
+  return dev_ctx->get_current_stream_handle();
+}
+
+phi::XPUStreamHandle *set_current_stream(int idx) {
+  int device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
+  auto original_stream = get_current_stream(device_id);
+  auto place = phi::XPUPlace(device_id);
+  auto *dev_ctx = static_cast<phi::XPUContext *>(
+      phi::DeviceContextPool::Instance().Get(place));
+  dev_ctx->SetCurrentStream(idx);
+  return original_stream;
 }
 
 #endif
 } // namespace platform
+
 namespace pybind {
 void BindXpuStream(py::module *m_ptr) {
   auto &m = *m_ptr;
@@ -69,7 +77,7 @@ void BindXpuStream(py::module *m_ptr) {
 #endif
       });
   m.def(
-      "_get_current_stream",
+      "_xpu_get_current_stream",
       [](int device_id) {
 #ifdef PADDLE_WITH_XPU
         if (device_id == -1) {
@@ -79,7 +87,19 @@ void BindXpuStream(py::module *m_ptr) {
         return platform::get_current_stream(device_id);
 #else
         PADDLE_THROW(
-            common::errors::Unavailable("Paddle is not compiled with CUDA. "
+            common::errors::Unavailable("Paddle is not compiled with XPU. "
+                                        "Cannot visit device synchronize."));
+#endif
+      },
+      py::return_value_policy::reference);
+  m.def(
+      "_xpu_set_current_stream",
+      [](int stream_id) {
+#ifdef PADDLE_WITH_XPU
+        return platform::set_current_stream(stream_id);
+#else
+        PADDLE_THROW(
+            common::errors::Unavailable("Paddle is not compiled with XPU. "
                                         "Cannot visit device synchronize."));
 #endif
       },
@@ -100,12 +120,167 @@ void BindXpuStream(py::module *m_ptr) {
 #endif
       });
 
+  py::class_<phi::XPUStreamHandle>(m, "XPUStream", R"DOC(
+      The handle of the XPU stream.
+
+      Parameters:
+          device(paddle.XPUPlace()|int|None, optional): The device which wanted to allocate the stream.
+              If device is None or negative integer, device will be the current device.
+              If device is positive integer, it must less than the device count. Default: None.
+
+      Examples:
+          .. code-block:: python
+
+              >>> # doctest: +REQUIRES(env:XPU)
+              >>> import paddle
+              >>> s1 = paddle.device.xpu.Stream(paddle.XPUPlace(0))
+              >>> s2 = paddle.device.xpu.Stream(0)
+              >>> s3 = paddle.device.xpu.Stream()
+
+      )DOC")
+#ifdef PADDLE_WITH_XPU
+      .def_property_readonly(
+          "xpu_stream",
+          [](phi::XPUStreamHandle &self) {
+            return reinterpret_cast<std::uintptr_t>(self.raw_stream());
+          })
+      .def("wait_stream",
+           [](phi::XPUStreamHandle &self, phi::XPUStreamHandle &other) {
+             auto *dev_ctx = phi::get_xpu_context();
+             dev_ctx->StreamWaitStreamInPool(self.id(), other.id());
+           })
+      .def("wait_event",
+           [](phi::XPUStreamHandle &self, phi::XPUEventHandle &other) {
+             self.wait_event(other.get_event());
+           })
+      .def("query",
+           [](phi::XPUStreamHandle &self) {
+             PADDLE_THROW(common::errors::Unavailable(
+                 "Query function for XPUStream is not supported now"));
+           })
+      .def("record_event",
+           [](phi::XPUStreamHandle &self, phi::XPUEventHandle *event) {
+             if (event == nullptr) {
+               event = new phi::XPUEventHandle();
+             }
+             self.record_event(event->get_event());
+             return event;
+           })
+      .def(
+          "synchronize",
+          [](phi::XPUStreamHandle &self) { self.synchronize(); },
+          R"DOC(
+          Waits for stream tasks to complete.
+
+          Examples:
+              .. code-block:: python
+
+                  >>> # doctest: +REQUIRES(env:XPU)
+                  >>> import paddle
+                  >>> s = paddle.device.xpu.Stream(paddle.XPUPlace(0), 1)
+                  >>> s.synchronize()
+
+          )DOC")
+      .def_property_readonly(
+          "place",
+          [](phi::XPUStreamHandle &self) {
+            return phi::XPUPlace(platform::GetXPUCurrentDeviceId());
+          })
+      .def_property_readonly(
+          "idx", [](phi::XPUStreamHandle &self) { return self.id(); })
+#endif
+
+      .def("__init__",
+           [](phi::XPUStreamHandle &self) {
+#ifdef PADDLE_WITH_XPU
+             new (&self) phi::XPUStreamHandle();
+             self.Init();
+#else
+             PADDLE_THROW(common::errors::Unavailable(
+                 "Class XPUStream can only be initialized on the XPU "
+                 "platform."));
+#endif
+           })
+      .def(
+          "__init__",
+          [](phi::XPUStreamHandle &self, phi::XPUPlace *place) {
+#ifdef PADDLE_WITH_XPU
+            if (place == nullptr) {
+              int curr_device_id = platform::GetXPUCurrentDeviceId();
+              auto place_tmp = phi::XPUPlace(curr_device_id);
+              new (&self) phi::XPUStreamHandle(place_tmp);
+            } else {
+              new (&self) phi::XPUStreamHandle(*place);
+            }
+#else
+            PADDLE_THROW(common::errors::Unavailable(
+                "Class XPUStream can only be initialized on the XPU "
+                "platform."));
+#endif
+          },
+          py::arg("device") = nullptr)
+      .def(
+          "__init__",
+          [](phi::XPUStreamHandle &self, int device) {
+#ifdef PADDLE_WITH_XPU
+            if (device < 0) {
+              device = platform::GetXPUCurrentDeviceId();
+            }
+            auto place_tmp = phi::XPUPlace(device);
+            new (&self) phi::XPUStreamHandle(place_tmp);
+#else
+            PADDLE_THROW(common::errors::Unavailable(
+                "Class XPUStream can only be initialized on the XPU "
+                "platform."));
+#endif
+          },
+          py::arg("device") = -1);
+  py::class_<phi::XPUEventHandle>(m, "XPUEvent", R"DOC(
+      The handle of the XPU event.
+
+      Examples:
+          .. code-block:: python
+
+              >>> # doctest: +REQUIRES(env:XPU)
+              >>> import paddle
+              >>> event = paddle.device.xpu.Event()
+
+      )DOC")
+#ifdef PADDLE_WITH_XPU
+      .def(
+          "record",
+          [](phi::XPUEventHandle &self, phi::XPUStreamHandle *stream) {
+            if (stream == nullptr) {
+              auto *dev_ctx = phi::get_xpu_context();
+              auto stream_handle = dev_ctx->get_current_stream_handle();
+              self.record(stream_handle->raw_stream());
+            } else {
+              self.record(stream->raw_stream());
+            }
+          },
+          py::arg("stream") = nullptr)
+      .def("query", [](phi::XPUEventHandle &self) { return self.query(); })
+      .def("elapsed_time",
+           [](phi::XPUEventHandle &self) {
+             PADDLE_THROW(common::errors::Unavailable(
+                 "XPUEvent elapsed_time is not supported now"));
+           })
+      .def("synchronize",
+           [](phi::XPUEventHandle &self) { self.synchronize(); })
+#endif
+      .def("__init__", [](phi::XPUEventHandle &self) {
+#ifdef PADDLE_WITH_XPU
+        new (&self) phi::XPUEventHandle();
+#else
+        PADDLE_THROW(common::errors::Unavailable(
+            "Class XPUEvent can only be initialized on the XPU platform."));
+#endif
+      });
 #ifdef PADDLE_WITH_XPU
-  py::class_<XPUStream>(m, "XPUStream", R"DOC(
-      The handle of the CUDA stream.
+  py::class_<phi::XPUCUDAStream>(m, "XPUCUDAStream", R"DOC(
+      The handle of the XPU stream.
 
       Parameters:
-          device(paddle.CUDAPlace()|int|None, optional): The device which wanted to allocate the stream.
+          device(paddle.XPUPlace()|int|None, optional): The device which wanted to allocate the stream.
               If device is None or negative integer, device will be the current device.
               If device is positive integer, it must less than the device count. Default: None.
           priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal).
@@ -114,16 +289,16 @@ void BindXpuStream(py::module *m_ptr) {
       Examples:
           .. code-block:: python
 
-              >>> # doctest: +REQUIRES(env:GPU)
+              >>> # doctest: +REQUIRES(env:XPU)
               >>> import paddle
-              >>> s1 = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
-              >>> s2 = paddle.device.cuda.Stream(0, 1)
-              >>> s3 = paddle.device.cuda.Stream()
+              >>> s1 = paddle.device.xpu.Stream(paddle.XPUPlace(0), 1)
+              >>> s2 = paddle.device.xpu.Stream(0, 1)
+              >>> s3 = paddle.device.xpu.Stream()
 
       )DOC")
       .def(
           "synchronize",
-          [](XPUStream &self) { xpu_wait(self); },
+          [](phi::XPUCUDAStream &self) { self.Synchronize(); },
          R"DOC(
          Waits for stream tasks to complete.
 
@@ -135,7 +310,25 @@ void BindXpuStream(py::module *m_ptr) {
               >>> s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
               >>> s.synchronize()
 
-      )DOC");
+      )DOC")
+      .def("__init__",
+           [](phi::XPUCUDAStream &self, phi::XPUPlace *place, int priority) {
+             if (priority != 1 && priority != 2) {
+               PADDLE_THROW(common::errors::InvalidArgument(
+                   "Priority should be 1(high) or 2(normal) "));
+             }
+             auto stream_flag =
+                 phi::XPUCUDAStream::StreamFlag::kStreamNonBlocking;
+             if (place == nullptr) {
+               int curr_device_id = platform::GetXPUCurrentDeviceId();
+               auto place_tmp = phi::XPUPlace(curr_device_id);
+               new (&self)
+                   phi::XPUCUDAStream(place_tmp, priority - 2, stream_flag);
+             } else {
+               new (&self)
+                   phi::XPUCUDAStream(*place, priority - 2, stream_flag);
+             }
+           });
 #endif
 }
 } // namespace pybind
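Of the private bindings above, `_xpu_get_current_stream` returns the handle of the current stream on a device (`-1` selects the current device), and `_xpu_set_current_stream` switches the device context to the pool stream with the given index, returning the handle that was current before the switch so the caller can restore it. A hedged sketch of that round trip; the `paddle.base.core` module path is an assumption for illustration, only the function names come from this diff:

    # Hypothetical round trip through the new private bindings; the module
    # path (paddle.base.core) is assumed, not established by this commit.
    from paddle.base import core

    cur = core._xpu_get_current_stream(-1)   # -1 -> current device's stream handle
    prev = core._xpu_set_current_stream(1)   # make pool stream 1 current
    # ... queue work on the now-current stream ...
    core._xpu_set_current_stream(prev.idx)   # restore the previous stream by index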

paddle/fluid/pybind/xpu_streams_py.h

Lines changed: 6 additions & 1 deletion
@@ -18,12 +18,16 @@
 #include "pybind11/stl.h"
 
 #ifdef PADDLE_WITH_XPU
+#include "paddle/phi/backends/xpu/xpu_context.h"
 #include "paddle/phi/core/xpu_cuda_stream.h"
 #include "xpu/runtime.h"
 #include "xpu/runtime_ex.h"
+
 #else
 namespace phi {
 class XPUCUDAStream {};
+class XPUStreamHandle {};
+class XPUEventHandle {};
 } // namespace phi
 #endif
 
@@ -32,7 +36,8 @@ namespace py = pybind11;
 namespace paddle {
 namespace platform {
 #ifdef PADDLE_WITH_XPU
-XPUStream get_current_stream(int device_id = -1);
+phi::XPUStreamHandle* get_current_stream(int device_id = -1);
+phi::XPUStreamHandle* set_current_stream(int idx);
 #endif
 } // namespace platform
 namespace pybind {
paddle/phi/api/include/tensor.h

Lines changed: 9 additions & 0 deletions
@@ -29,6 +29,11 @@ using gpuStream_t = cudaStream_t;
 using gpuStream_t = hipStream_t;
 #endif
 
+#ifdef PADDLE_WITH_XPU
+#include "xpu/runtime.h"
+#include "xpu/runtime_ex.h"
+#endif
+
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
 #include "paddle/phi/backends/stream.h"
 #endif
@@ -434,6 +439,10 @@ class PADDLE_API Tensor final {
    * @return gpuStream_t
    */
   gpuStream_t stream() const;
+#elif defined(PADDLE_WITH_XPU)
+
+  void record_stream(XPUStream stream) const;
+
 #elif defined(PADDLE_WITH_CUSTOM_DEVICE)
   /**
    * @brief Get the stream where the tensor is currently located

paddle/phi/api/lib/tensor.cc

Lines changed: 10 additions & 0 deletions
@@ -40,6 +40,8 @@ limitations under the License. */
 #include "paddle/phi/core/tensor_meta.h"
 #include "paddle/phi/core/tensor_utils.h"
 
+#include "paddle/phi/core/memory/malloc.h"
+
 namespace paddle {
 
 using DeviceContextPool = experimental::DeviceContextPool;
@@ -397,6 +399,14 @@ Tensor Tensor::slice(int64_t begin_idx, int64_t end_idx) const {
 
 const std::shared_ptr<phi::TensorBase> &Tensor::impl() const { return impl_; }
 
+#ifdef PADDLE_WITH_XPU
+
+void Tensor::record_stream(XPUStream stream) const {
+  paddle::memory::RecordStream(
+      std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->Holder(), stream);
+}
+
+#endif
 void Tensor::set_impl(const std::shared_ptr<phi::TensorBase> &impl) {
   impl_ = impl;
 }
