Skip to content

Commit 17d6201

Browse files
authored
[Compat] Remove device parameter from Event class and update documentation (PaddlePaddle#76379)
1 parent bead0c2 commit 17d6201

File tree

2 files changed

+37
-30
lines changed

2 files changed

+37
-30
lines changed

python/paddle/device/__init__.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,27 +1018,46 @@ class Event:
10181018
A device event wrapper around StreamBase.
10191019
10201020
Args:
1021-
device(str|paddle.CUDAPlace(n)|paddle.CustomPlace(n)|None): Which device the stream run on. If device is None, the device is the current device. Default: None.
1022-
It can be ``gpu``, ``gpu:x``, ``custom_device``, ``custom_device:x``, where ``custom_device`` is the name of CustomDevice,
1023-
where ``x`` is the index of the GPUs, XPUs. And it can be paddle.CUDAPlace(n) or paddle.CustomPlace(n).
10241021
enable_timing (bool, optional): indicates if the event should measure time, default is False
10251022
blocking (bool, optional): if True, ``wait`` will be blocking, default is False
10261023
interprocess (bool): if True, the event can be shared between processes, default is False
10271024
10281025
Returns:
10291026
Event: The event.
10301027
1028+
Note:
1029+
The `device` parameter has been removed in the latest version. The event will always use the current device context.
1030+
Previously, you could specify the device like:
1031+
```python
1032+
# Old usage (no longer supported)
1033+
e = paddle.device.Event(device="gpu:0")
1034+
```
1035+
Now it will automatically use the current device:
1036+
```python
1037+
# New usage
1038+
paddle.set_device("gpu:0") # Set device first
1039+
e = paddle.device.Event() # Will use gpu:0
1040+
```
1041+
1042+
paddle.device.Event is equivalent to paddle.cuda.Event.
1043+
10311044
Examples:
10321045
.. code-block:: python
10331046
10341047
>>> # doctest: +REQUIRES(env:CUSTOM_DEVICE)
10351048
>>> import paddle
10361049
10371050
>>> paddle.set_device('custom_cpu')
1038-
>>> e1 = paddle.device.Event()
1039-
>>> e2 = paddle.device.Event('custom_cpu')
1040-
>>> e3 = paddle.device.Event('custom_cpu:0')
1041-
>>> e4 = paddle.device.Event(paddle.CustomPlace('custom_cpu', 0))
1051+
>>> e1 = paddle.device.Event() # Uses current device (custom_cpu)
1052+
>>>
1053+
>>> # Old usage (no longer supported):
1054+
>>> # e2 = paddle.device.Event('custom_cpu')
1055+
>>> # e3 = paddle.device.Event('custom_cpu:0')
1056+
>>> # e4 = paddle.device.Event(paddle.CustomPlace('custom_cpu', 0))
1057+
>>>
1058+
>>> # New equivalent usage:
1059+
>>> paddle.set_device('custom_cpu:0')
1060+
>>> e5 = paddle.device.Event() # Uses custom_cpu:0
10421061
10431062
'''
10441063

@@ -1048,17 +1067,11 @@ class Event:
10481067

10491068
def __init__(
10501069
self,
1051-
device: PlaceLike | None = None,
10521070
enable_timing: bool = False,
10531071
blocking: bool = False,
10541072
interprocess: bool = False,
10551073
) -> None:
1056-
if device is None:
1057-
self.device = paddle.framework._current_expected_place_()
1058-
elif isinstance(device, str):
1059-
self.device = paddle.device._convert_to_place(device)
1060-
else:
1061-
self.device = device
1074+
self.device = paddle.framework._current_expected_place_()
10621075

10631076
device_id = (
10641077
self.device.get_device_id()
@@ -1340,7 +1353,7 @@ def record_event(self, event: Event | None = None) -> Event:
13401353
13411354
'''
13421355
if event is None:
1343-
event = Event(self.device)
1356+
event = Event()
13441357
event.record(self)
13451358
return event
13461359

test/compat/test_event_stream_apis.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,10 @@ def _test_event_stream_apis_impl(self, device_str):
9090
event1 = paddle.device.Event()
9191
self.assertIsInstance(event1, paddle.device.Event)
9292

93-
event2 = paddle.device.Event(device=device_str, enable_timing=True)
93+
event2 = paddle.device.Event(enable_timing=True)
9494
self.assertIsInstance(event2, paddle.device.Event)
9595

96-
event3 = paddle.device.Event(
97-
device=device_str, enable_timing=True, blocking=True
98-
)
96+
event3 = paddle.device.Event(enable_timing=True, blocking=True)
9997
self.assertIsInstance(event3, paddle.device.Event)
10098

10199
# Test Stream creation with different parameters
@@ -159,12 +157,8 @@ def _test_event_stream_apis_impl(self, device_str):
159157
# Test Event.elapsed_time()
160158
if hasattr(event1, 'event_base') and hasattr(event2, 'event_base'):
161159
# Create events with timing enabled
162-
start_event = paddle.device.Event(
163-
device=device_str, enable_timing=True
164-
)
165-
end_event = paddle.device.Event(
166-
device=device_str, enable_timing=True
167-
)
160+
start_event = paddle.device.Event(enable_timing=True)
161+
end_event = paddle.device.Event(enable_timing=True)
168162

169163
# Record start event
170164
start_event.record()
@@ -243,7 +237,7 @@ def _test_event_stream_apis_impl(self, device_str):
243237
def test_event_stream_error_handling(self):
244238
"""Test Event and Stream error handling."""
245239
# Test with invalid device types
246-
with self.assertRaises(ValueError):
240+
with self.assertRaises(TypeError):
247241
paddle.device.Event(device='invalid_device:0')
248242

249243
with self.assertRaises(ValueError):
@@ -258,8 +252,8 @@ def test_event_stream_error_handling(self):
258252
)
259253
paddle.device.set_device(device_str)
260254

261-
event1 = paddle.device.Event(device=device_str)
262-
event2 = paddle.device.Event(device=device_str)
255+
event1 = paddle.device.Event()
256+
event2 = paddle.device.Event()
263257

264258
# Should not raise exception even if events are not recorded
265259
try:
@@ -320,8 +314,8 @@ def test_event_stream_timing_functionality(self):
320314
paddle.device.set_device(device_str)
321315

322316
# Create events with timing enabled
323-
start_event = paddle.device.Event(device=device_str, enable_timing=True)
324-
end_event = paddle.device.Event(device=device_str, enable_timing=True)
317+
start_event = paddle.device.Event(enable_timing=True)
318+
end_event = paddle.device.Event(enable_timing=True)
325319

326320
# Create a stream for work execution
327321
stream = paddle.device.Stream(device=device_str)

0 commit comments

Comments
 (0)