@@ -33,19 +33,27 @@ namespace py = pybind11;
33
33
namespace paddle {
34
34
namespace platform {
35
35
#ifdef PADDLE_WITH_XPU
36
- XPUStream get_current_stream (int device_id) {
37
- if (device_id == -1 ) {
38
- device_id = phi::backends::xpu::GetXPUCurrentDeviceId ();
39
- }
36
+ phi::XPUStreamHandle *get_current_stream (int device_id) {
40
37
auto place = phi::XPUPlace (device_id);
41
38
auto *dev_ctx = static_cast <phi::XPUContext *>(
42
39
phi::DeviceContextPool::Instance ().Get (place));
43
40
dev_ctx->Wait ();
44
- return dev_ctx->stream ();
41
+ return dev_ctx->get_current_stream_handle ();
42
+ }
43
+
44
+ phi::XPUStreamHandle *set_current_stream (int idx) {
45
+ int device_id = phi::backends::xpu::GetXPUCurrentDeviceId ();
46
+ auto original_stream = get_current_stream (device_id);
47
+ auto place = phi::XPUPlace (device_id);
48
+ auto *dev_ctx = static_cast <phi::XPUContext *>(
49
+ phi::DeviceContextPool::Instance ().Get (place));
50
+ dev_ctx->SetCurrentStream (idx);
51
+ return original_stream;
45
52
}
46
53
47
54
#endif
48
55
} // namespace platform
56
+
49
57
namespace pybind {
50
58
void BindXpuStream (py::module *m_ptr) {
51
59
auto &m = *m_ptr;
@@ -69,7 +77,7 @@ void BindXpuStream(py::module *m_ptr) {
69
77
#endif
70
78
});
71
79
m.def (
72
- " _get_current_stream " ,
80
+ " _xpu_get_current_stream " ,
73
81
[](int device_id) {
74
82
#ifdef PADDLE_WITH_XPU
75
83
if (device_id == -1 ) {
@@ -79,7 +87,19 @@ void BindXpuStream(py::module *m_ptr) {
79
87
return platform::get_current_stream (device_id);
80
88
#else
81
89
PADDLE_THROW (
82
- common::errors::Unavailable (" Paddle is not compiled with CUDA. "
90
+ common::errors::Unavailable (" Paddle is not compiled with XPU. "
91
+ " Cannot visit device synchronize." ));
92
+ #endif
93
+ },
94
+ py::return_value_policy::reference);
95
+ m.def (
96
+ " _xpu_set_current_stream" ,
97
+ [](int stream_id) {
98
+ #ifdef PADDLE_WITH_XPU
99
+ return platform::set_current_stream (stream_id);
100
+ #else
101
+ PADDLE_THROW (
102
+ common::errors::Unavailable (" Paddle is not compiled with XPU. "
83
103
" Cannot visit device synchronize." ));
84
104
#endif
85
105
},
@@ -100,12 +120,167 @@ void BindXpuStream(py::module *m_ptr) {
100
120
#endif
101
121
});
102
122
123
+ py::class_<phi::XPUStreamHandle>(m, " XPUStream" , R"DOC(
124
+ The handle of the XPU stream.
125
+
126
+ Parameters:
127
+ device(paddle.XPUPlace()|int|None, optional): The device which wanted to allocate the stream.
128
+ If device is None or negative integer, device will be the current device.
129
+ If device is positive integer, it must less than the device count. Default: None.
130
+
131
+ Examples:
132
+ .. code-block:: python
133
+
134
+ >>> # doctest: +REQUIRES(env:XPU)
135
+ >>> import paddle
136
+ >>> s1 = paddle.device.xpu.Stream(paddle.XPUPlace(0))
137
+ >>> s2 = paddle.device.xpu.Stream(0)
138
+ >>> s3 = paddle.device.xpu.Stream()
139
+
140
+ )DOC" )
141
+ #ifdef PADDLE_WITH_XPU
142
+ .def_property_readonly (
143
+ " xpu_stream" ,
144
+ [](phi::XPUStreamHandle &self) {
145
+ return reinterpret_cast <std::uintptr_t >(self.raw_stream ());
146
+ })
147
+ .def (" wait_stream" ,
148
+ [](phi::XPUStreamHandle &self, phi::XPUStreamHandle &other) {
149
+ auto *dev_ctx = phi::get_xpu_context ();
150
+ dev_ctx->StreamWaitStreamInPool (self.id (), other.id ());
151
+ })
152
+ .def (" wait_event" ,
153
+ [](phi::XPUStreamHandle &self, phi::XPUEventHandle &other) {
154
+ self.wait_event (other.get_event ());
155
+ })
156
+ .def (" query" ,
157
+ [](phi::XPUStreamHandle &self) {
158
+ PADDLE_THROW (common::errors::Unavailable (
159
+ " Query function for XPUStream is not supported now" ));
160
+ })
161
+ .def (" record_event" ,
162
+ [](phi::XPUStreamHandle &self, phi::XPUEventHandle *event) {
163
+ if (event == nullptr ) {
164
+ event = new phi::XPUEventHandle ();
165
+ }
166
+ self.record_event (event->get_event ());
167
+ return event;
168
+ })
169
+ .def (
170
+ " synchronize" ,
171
+ [](phi::XPUStreamHandle &self) { self.synchronize (); },
172
+ R"DOC(
173
+ Waits for stream tasks to complete.
174
+
175
+ Examples:
176
+ .. code-block:: python
177
+
178
+ >>> # doctest: +REQUIRES(env:XPU)
179
+ >>> import paddle
180
+ >>> s = paddle.device.xpu.Stream(paddle.XPUPlace(0), 1)
181
+ >>> s.synchronize()
182
+
183
+ )DOC" )
184
+ .def_property_readonly (
185
+ " place" ,
186
+ [](phi::XPUStreamHandle &self) {
187
+ return phi::XPUPlace (platform::GetXPUCurrentDeviceId ());
188
+ })
189
+ .def_property_readonly (
190
+ " idx" , [](phi::XPUStreamHandle &self) { return self.id (); })
191
+ #endif
192
+
193
+ .def (" __init__" ,
194
+ [](phi::XPUStreamHandle &self) {
195
+ #ifdef PADDLE_WITH_XPU
196
+ new (&self) phi::XPUStreamHandle ();
197
+ self.Init ();
198
+ #else
199
+ PADDLE_THROW (common::errors::Unavailable (
200
+ " Class XPUStream can only be initialized on the XPU "
201
+ " platform." ));
202
+ #endif
203
+ })
204
+ .def (
205
+ " __init__" ,
206
+ [](phi::XPUStreamHandle &self, phi::XPUPlace *place) {
207
+ #ifdef PADDLE_WITH_XPU
208
+ if (place == nullptr ) {
209
+ int curr_device_id = platform::GetXPUCurrentDeviceId ();
210
+ auto place_tmp = phi::XPUPlace (curr_device_id);
211
+ new (&self) phi::XPUStreamHandle (place_tmp);
212
+ } else {
213
+ new (&self) phi::XPUStreamHandle (*place);
214
+ }
215
+ #else
216
+ PADDLE_THROW (common::errors::Unavailable (
217
+ " Class XPUStream can only be initialized on the XPU "
218
+ " platform." ));
219
+ #endif
220
+ },
221
+ py::arg (" device" ) = nullptr )
222
+ .def (
223
+ " __init__" ,
224
+ [](phi::XPUStreamHandle &self, int device) {
225
+ #ifdef PADDLE_WITH_XPU
226
+ if (device < 0 ) {
227
+ device = platform::GetXPUCurrentDeviceId ();
228
+ }
229
+ auto place_tmp = phi::XPUPlace (device);
230
+ new (&self) phi::XPUStreamHandle (place_tmp);
231
+ #else
232
+ PADDLE_THROW (common::errors::Unavailable (
233
+ " Class XPUStream can only be initialized on the XPU "
234
+ " platform." ));
235
+ #endif
236
+ },
237
+ py::arg (" device" ) = -1 );
238
+ py::class_<phi::XPUEventHandle>(m, " XPUEvent" , R"DOC(
239
+ The handle of the XPU event.
240
+
241
+ Examples:
242
+ .. code-block:: python
243
+
244
+ >>> # doctest: +REQUIRES(env:XPU)
245
+ >>> import paddle
246
+ >>> event = paddle.device.xpu.Event()
247
+
248
+ )DOC" )
249
+ #ifdef PADDLE_WITH_XPU
250
+ .def (
251
+ " record" ,
252
+ [](phi::XPUEventHandle &self, phi::XPUStreamHandle *stream) {
253
+ if (stream == nullptr ) {
254
+ auto *dev_ctx = phi::get_xpu_context ();
255
+ auto stream_handle = dev_ctx->get_current_stream_handle ();
256
+ self.record (stream_handle->raw_stream ());
257
+ } else {
258
+ self.record (stream->raw_stream ());
259
+ }
260
+ },
261
+ py::arg (" stream" ) = nullptr )
262
+ .def (" query" , [](phi::XPUEventHandle &self) { return self.query (); })
263
+ .def (" elapsed_time" ,
264
+ [](phi::XPUEventHandle &self) {
265
+ PADDLE_THROW (common::errors::Unavailable (
266
+ " XPUEvent elapsed_time is not supported now" ));
267
+ })
268
+ .def (" synchronize" , [](phi::XPUEventHandle &self) { self.synchronize (); })
269
+ #endif
270
+ .def (" __init__" , [](phi::XPUEventHandle &self) {
271
+ #ifdef PADDLE_WITH_XPU
272
+ new (&self) phi::XPUEventHandle ();
273
+ #else
274
+ PADDLE_THROW (common::errors::Unavailable (
275
+ " Class XPUEvent can only be initialized on the XPU platform." ));
276
+ #endif
277
+ });
103
278
#ifdef PADDLE_WITH_XPU
104
- py::class_<XPUStream >(m, " XPUStream " , R"DOC(
105
- The handle of the CUDA stream.
279
+ py::class_<phi::XPUCUDAStream >(m, " XPUCUDAStream " , R"DOC(
280
+ The handle of the XPU stream.
106
281
107
282
Parameters:
108
- device(paddle.CUDAPlace ()|int|None, optional): The device which wanted to allocate the stream.
283
+ device(paddle.XPUPlace ()|int|None, optional): The device which wanted to allocate the stream.
109
284
If device is None or negative integer, device will be the current device.
110
285
If device is positive integer, it must less than the device count. Default: None.
111
286
priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal).
@@ -114,16 +289,16 @@ void BindXpuStream(py::module *m_ptr) {
114
289
Examples:
115
290
.. code-block:: python
116
291
117
- >>> # doctest: +REQUIRES(env:GPU )
292
+ >>> # doctest: +REQUIRES(env:XPU )
118
293
>>> import paddle
119
- >>> s1 = paddle.device.cuda .Stream(paddle.CUDAPlace (0), 1)
120
- >>> s2 = paddle.device.cuda .Stream(0, 1)
121
- >>> s3 = paddle.device.cuda .Stream()
294
+ >>> s1 = paddle.device.xpu .Stream(paddle.XPUPlace (0), 1)
295
+ >>> s2 = paddle.device.xpu .Stream(0, 1)
296
+ >>> s3 = paddle.device.xpu .Stream()
122
297
123
298
)DOC" )
124
299
.def (
125
300
" synchronize" ,
126
- [](XPUStream &self) { xpu_wait ( self); },
301
+ [](phi::XPUCUDAStream &self) { self. Synchronize ( ); },
127
302
R"DOC(
128
303
Waits for stream tasks to complete.
129
304
@@ -135,7 +310,25 @@ void BindXpuStream(py::module *m_ptr) {
135
310
>>> s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
136
311
>>> s.synchronize()
137
312
138
- )DOC" );
313
+ )DOC" )
314
+ .def (" __init__" ,
315
+ [](phi::XPUCUDAStream &self, phi::XPUPlace *place, int priority) {
316
+ if (priority != 1 && priority != 2 ) {
317
+ PADDLE_THROW (common::errors::InvalidArgument (
318
+ " Priority should be 1(high) or 2(normal) " ));
319
+ }
320
+ auto stream_flag =
321
+ phi::XPUCUDAStream::StreamFlag::kStreamNonBlocking ;
322
+ if (place == nullptr ) {
323
+ int curr_device_id = platform::GetXPUCurrentDeviceId ();
324
+ auto place_tmp = phi::XPUPlace (curr_device_id);
325
+ new (&self)
326
+ phi::XPUCUDAStream (place_tmp, priority - 2 , stream_flag);
327
+ } else {
328
+ new (&self)
329
+ phi::XPUCUDAStream (*place, priority - 2 , stream_flag);
330
+ }
331
+ });
139
332
#endif
140
333
}
141
334
} // namespace pybind
0 commit comments