Skip to content

Commit 6fd1723

Browse files
authored
[Compat] Compatible with API development (PaddlePaddle#76247)
1 parent f956a3b commit 6fd1723

File tree

8 files changed

+157
-123
lines changed

8 files changed

+157
-123
lines changed

python/paddle/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def new_init(self, *args, **kwargs):
271271
set_default_dtype,
272272
)
273273
from .framework.random import (
274+
Generator,
274275
get_cuda_rng_state,
275276
get_rng_state,
276277
seed,
@@ -1484,6 +1485,7 @@ def __dir__(self):
14841485
'conv3d',
14851486
'manual_seed',
14861487
'softmax',
1488+
'Generator',
14871489
'adaptive_avg_pool1d',
14881490
'autocast',
14891491
]

python/paddle/compat/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -955,11 +955,7 @@ def __getattr__(self, name: str) -> Any:
955955
return getattr(self._original_module, name)
956956

957957

958-
GLOBAL_OVERRIDES = {
959-
"torch.Generator": create_fake_class(
960-
"Generator", {"manual_seed": create_fake_function("manual_seed")}
961-
),
962-
}
958+
GLOBAL_OVERRIDES = {}
963959

964960

965961
def _is_torch_module(name: str) -> bool:

python/paddle/cuda/__init__.py

Lines changed: 4 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@
2323
from paddle.device import (
2424
Event,
2525
Stream,
26+
StreamContext,
2627
_device_to_paddle as _device_to_paddle,
2728
amp, # noqa: F401
29+
current_device,
2830
device,
2931
is_available as _device_is_available,
3032
is_bf16_supported,
@@ -33,7 +35,7 @@
3335
manual_seed_all as device_manual_seed_all,
3436
reset_peak_memory_stats,
3537
set_stream,
36-
stream_guard as _PaddleStreamGuard,
38+
stream,
3739
)
3840
from paddle.tensor.creation import (
3941
BFloat16Tensor,
@@ -284,39 +286,6 @@ def manual_seed_all(seed: int) -> None:
284286
device_manual_seed_all(seed)
285287

286288

287-
class StreamContext(_PaddleStreamGuard):
288-
"""
289-
Notes:
290-
This API only supports dynamic graph mode currently.
291-
A context manager that specifies the current stream context by the given stream.
292-
293-
Args:
294-
stream(Stream, optional): the selected stream. If stream is None, just yield.
295-
296-
Returns:
297-
None.
298-
299-
Examples:
300-
.. code-block:: python
301-
302-
>>> # doctest: +REQUIRES(env:CUSTOM_DEVICE)
303-
>>> import paddle
304-
305-
>>> paddle.set_device('cuda')
306-
>>> s = paddle.cuda.Stream()
307-
>>> data1 = paddle.ones(shape=[20])
308-
>>> data2 = paddle.ones(shape=[20])
309-
>>> data3 = data1 + data2
310-
>>> with paddle.cuda.StreamContext(s):
311-
... s.wait_stream(paddle.cuda.current_stream()) # type: ignore[attr-defined]
312-
... data4 = data1 + data3
313-
314-
"""
315-
316-
def __init__(self, stream: paddle_device.Stream):
317-
super().__init__(stream)
318-
319-
320289
def get_rng_state(device: DeviceLike | None = None) -> core.GeneratorState:
321290
"""
322291
Return the random number generator state of the specified device.
@@ -369,40 +338,6 @@ def set_rng_state(
369338
paddle_device.set_rng_state(new_state, device)
370339

371340

372-
def stream(stream_obj: paddle_device.Stream | None) -> StreamContext:
373-
'''
374-
375-
Notes:
376-
This API only supports dynamic graph mode currently.
377-
A context manager that specifies the current stream context by the given stream.
378-
379-
Args:
380-
stream(Stream, optional): the selected stream. If stream is None, just yield.
381-
382-
Returns:
383-
None.
384-
385-
Examples:
386-
.. code-block:: python
387-
388-
>>> # doctest: +REQUIRES(env:CUSTOM_DEVICE)
389-
>>> import paddle
390-
391-
>>> paddle.set_device('cuda')
392-
>>> s = paddle.cuda.Stream()
393-
>>> data1 = paddle.ones(shape=[20])
394-
>>> data2 = paddle.ones(shape=[20])
395-
>>> data3 = data1 + data2
396-
397-
>>> with paddle.cuda.stream(s):
398-
... s.wait_stream(paddle.cuda.current_stream())
399-
... data4 = data1 + data3
400-
>>> print(data4)
401-
402-
'''
403-
return StreamContext(stream_obj)
404-
405-
406341
class nvtx:
407342
"""Namespace for NVTX marker operations."""
408343

@@ -559,35 +494,6 @@ def mem_get_info(device: DeviceLike = None) -> tuple[int, int]:
559494
return cudart().cudaMemGetInfo(device_id)
560495

561496

562-
def current_device() -> int:
563-
"""
564-
Return the index of a currently selected device.
565-
566-
Returns:
567-
int: The index of the currently selected device.
568-
569-
Examples:
570-
.. code-block:: python
571-
572-
>>> # doctest: +REQUIRES(env:GPU)
573-
>>> import paddle
574-
>>> device_id = paddle.cuda.current_device()
575-
>>> print(f"Current device index: {device_id}")
576-
"""
577-
# Use paddle.device.get_device() to get the current device string
578-
device_str = paddle_device.get_device()
579-
580-
# Parse the device string to extract the device index
581-
# Format examples: 'gpu:0', 'xpu:0', 'custom_device:0'
582-
if ':' in device_str:
583-
device_id = int(device_str.split(':')[1])
584-
else:
585-
# If no device index is specified, default to 0
586-
device_id = 0
587-
588-
return device_id
589-
590-
591497
def device_count() -> int:
592498
"""
593499
Return the number of devices available.
@@ -972,4 +878,5 @@ def get_stream_from_external(
972878
"max_memory_allocated",
973879
"reset_peak_memory_stats",
974880
"Event",
881+
"StreamContext",
975882
]

python/paddle/device/__init__.py

Lines changed: 95 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,35 @@ def __exit__(
516516
return False
517517

518518

519+
def current_device() -> int:
520+
"""
521+
Return the index of a currently selected device.
522+
523+
Returns:
524+
int: The index of the currently selected device.
525+
526+
Examples:
527+
.. code-block:: python
528+
529+
>>> # doctest: +REQUIRES(env:GPU)
530+
>>> import paddle
531+
>>> device_id = paddle.device.current_device() # this is equivalent to paddle.cuda.current_device()
532+
>>> print(f"Current device index: {device_id}")
533+
"""
534+
# Use paddle.device.get_device() to get the current device string
535+
device_str = get_device()
536+
537+
# Parse the device string to extract the device index
538+
# Format examples: 'gpu:0', 'xpu:0', 'custom_device:0'
539+
if ':' in device_str:
540+
device_id = int(device_str.split(':')[1])
541+
else:
542+
# If no device index is specified, default to 0
543+
device_id = 0
544+
545+
return device_id
546+
547+
519548
def is_bf16_supported(including_emulation: bool = True) -> bool:
520549
"""
521550
Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16.
@@ -538,9 +567,26 @@ def is_bf16_supported(including_emulation: bool = True) -> bool:
538567
539568
"""
540569
# including_emulation is not used here, but kept for compatibility with the original implementation
541-
return core.is_bfloat16_supported(
542-
paddle.framework._current_expected_place()
543-
)
570+
if core.is_bfloat16_supported(paddle.framework._current_expected_place()):
571+
return True
572+
573+
# If CUDA is not available, then it does not support bf16 either
574+
if not is_available():
575+
return False
576+
577+
device = get_device()
578+
579+
# Check for CUDA version and device compute capability.
580+
# This is a fast way to check for it.
581+
if not including_emulation:
582+
return False
583+
584+
# Finally, try to create a bfloat16 tensor on the device.
585+
try:
586+
paddle.tensor([1.0], dtype=paddle.bfloat16, device=device)
587+
return True
588+
except:
589+
return False
544590

545591

546592
def set_device(device: PlaceLike | int) -> PlaceLike:
@@ -1584,7 +1630,7 @@ class stream_guard:
15841630
>>> data1 = paddle.ones(shape=[20])
15851631
>>> data2 = paddle.ones(shape=[20])
15861632
>>> data3 = data1 + data2
1587-
>>> with paddle.device.stream_guard(s):
1633+
>>> with paddle.device.stream_guard(s):  # this is equivalent to paddle.cuda.StreamContext(s) and paddle.device.StreamContext(s)
15881634
... s.wait_stream(paddle.device.default_stream()) # type: ignore[attr-defined]
15891635
... data4 = data1 + data3
15901636
@@ -1627,6 +1673,43 @@ def __exit__(
16271673
set_stream(self.src_prev_stream)
16281674

16291675

1676+
StreamContext = stream_guard
1677+
1678+
1679+
def stream(stream: Stream | None) -> stream_guard:
1680+
'''
1681+
1682+
Notes:
1683+
This API only supports dynamic graph mode currently.
1684+
A context manager that specifies the current stream context by the given stream.
1685+
1686+
Args:
1687+
stream(Stream, optional): the selected stream. If stream is None, just yield.
1688+
1689+
Returns:
1690+
None.
1691+
1692+
Examples:
1693+
.. code-block:: python
1694+
1695+
>>> # doctest: +REQUIRES(env:CUSTOM_DEVICE)
1696+
>>> import paddle
1697+
1698+
>>> paddle.set_device('cuda')
1699+
>>> s = paddle.device.Stream()
1700+
>>> data1 = paddle.ones(shape=[20])
1701+
>>> data2 = paddle.ones(shape=[20])
1702+
>>> data3 = data1 + data2
1703+
1704+
>>> with paddle.device.stream(s): # this is equivalent to paddle.cuda.stream(s)
1705+
... s.wait_stream(paddle.cuda.current_stream())
1706+
... data4 = data1 + data3
1707+
>>> print(data4)
1708+
1709+
'''
1710+
return StreamContext(stream)
1711+
1712+
16301713
class device_guard:
16311714
'''
16321715
@@ -1900,15 +1983,16 @@ def reset_peak_memory_stats(device: PlaceLike | int | None = None) -> None:
19001983
It sets the peak memory usage back to zero for all devices.
19011984
19021985
Example:
1903-
>>> # doctest: +REQUIRES(env:GPU)
1904-
>>> import paddle
1905-
>>> paddle.device.set_device('gpu') # or '<custom_device>'
1986+
.. code-block:: python
1987+
>>> # doctest: +REQUIRES(env:GPU)
1988+
>>> import paddle
1989+
>>> paddle.device.set_device('gpu') # or '<custom_device>'
19061990
1907-
>>> # paddle.cuda.reset_max_memory_allocated() is equivalent to paddle.device.reset_max_memory_allocated()
1991+
>>> # paddle.cuda.reset_max_memory_allocated() is equivalent to paddle.device.reset_max_memory_allocated()
19081992
1909-
>>> paddle.device.reset_max_memory_allocated(paddle.CUDAPlace(0))
1910-
>>> paddle.device.reset_max_memory_allocated(0)
1911-
>>> paddle.device.reset_max_memory_allocated("gpu:0")
1993+
>>> paddle.device.reset_max_memory_allocated(paddle.CUDAPlace(0))
1994+
>>> paddle.device.reset_max_memory_allocated(0)
1995+
>>> paddle.device.reset_max_memory_allocated("gpu:0")
19121996
"""
19131997
reset_max_memory_allocated()
19141998

python/paddle/framework/random.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,3 +271,32 @@ def set_random_seed_generator(name: str, seed: int) -> None:
271271

272272
def get_random_seed_generator(name: str) -> paddle.base.core.Generator:
273273
return core.get_random_seed_generator(name)
274+
275+
276+
class Generator:
277+
def __new__(
278+
cls, device: str | int | paddle.core.Place = None
279+
) -> core.Generator:
280+
"""
281+
Generator is a random number generator.
282+
283+
Args:
284+
device(str|int|paddle.core.Place): The device type to create the generator on.
285+
It can be ``cpu``, ``gpu``, ``xpu``, or a paddle.core.Place instance.
286+
Defaults to None, which means the current device is used.
287+
288+
Examples:
289+
.. code-block:: python
290+
291+
>>> import paddle
292+
>>> g_cpu = paddle.Generator()
293+
"""
294+
place = paddle.device.device_to_place(device)
295+
if isinstance(place, core.CPUPlace):
296+
return core.default_cpu_generator()
297+
elif isinstance(place, core.CUDAPlace):
298+
return core.default_cuda_generator(place.gpu_device_id())
299+
elif isinstance(place, core.XPUPlace):
300+
return core.default_xpu_generator(place.gpu_device_id())
301+
elif isinstance(place, core.CustomPlace):
302+
return core.default_custom_device_generator(place)

test/compat/test_torch_proxy.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,5 @@ def test_use_torch_inside_inner_function(self):
8686
)
8787

8888

89-
class TestTorchOverriddenClass(unittest.TestCase):
90-
def test_overridden_class(self):
91-
self.assertRaises(AttributeError, lambda: paddle.Generator)
92-
with paddle.compat.use_torch_proxy_guard():
93-
import torch
94-
95-
gen = torch.Generator()
96-
97-
9889
if __name__ == "__main__":
9990
unittest.main()

0 commit comments

Comments
 (0)