Skip to content

Commit 0d9b235

Browse files
[API compatibility] add paddle.compat.seed method (PaddlePaddle#76440)
* add paddle.compat.seed method * add paddle.compat.seed method * add api doc * add xpu test * add xpu test
1 parent a5bd401 commit 0d9b235

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

python/paddle/compat/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import paddle
2020
from paddle import _C_ops
21+
from paddle.base import core
2122
from paddle.base.framework import Variable
2223
from paddle.framework import (
2324
in_dynamic_mode,
@@ -47,6 +48,7 @@
4748
'max',
4849
'median',
4950
'nanmedian',
51+
'seed',
5052
]
5153

5254

@@ -234,6 +236,22 @@ def nanmedian(
234236
return MedianRetType(values=values, indices=indices)
235237

236238

239+
def seed() -> int:
240+
r"""Sets the seed for generating random numbers to a non-deterministic
241+
random number on all devices. Returns a 64 bit number used to seed the RNG.
242+
Returns:
243+
Returns: int64, the seed used to seed the RNG.
244+
Examples:
245+
.. code-block:: python
246+
247+
>>> import paddle
248+
>>> seed = paddle.compat.seed()
249+
"""
250+
seed = core.default_cpu_generator().seed()
251+
paddle.seed(seed)
252+
return seed
253+
254+
237255
class MinMaxRetType(NamedTuple):
238256
values: Tensor
239257
indices: Tensor
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import paddle
18+
from paddle.base import core
19+
from paddle.compat import seed as compat_seed
20+
21+
22+
class TestCompatSeed(unittest.TestCase):
23+
def test_seed(self):
24+
paddle.seed(42)
25+
seed_cpu_random = core.default_cpu_generator().random()
26+
if paddle.is_compiled_with_cuda():
27+
seed_gpu_random = core.default_cuda_generator(0).random()
28+
if paddle.is_compiled_with_xpu():
29+
seed_xpu_random = core.default_xpu_generator(0).random()
30+
paddle.seed(42)
31+
compat_seed()
32+
compat_seed_cpu_random = core.default_cpu_generator().random()
33+
34+
if paddle.is_compiled_with_cuda():
35+
compat_seed_gpu_random = core.default_cuda_generator(0).random()
36+
assert seed_gpu_random != compat_seed_gpu_random, (
37+
"GPU Random Seed Not Change!"
38+
)
39+
if paddle.is_compiled_with_xpu():
40+
compat_seed_xpu_random = core.default_xpu_generator(0).random()
41+
assert seed_xpu_random != compat_seed_xpu_random, (
42+
"XPU Random Seed Not Change!"
43+
)
44+
45+
assert seed_cpu_random != compat_seed_cpu_random, (
46+
"CPU Random Seed Not Change!"
47+
)
48+
49+
50+
if __name__ == '__main__':
51+
unittest.main()

0 commit comments

Comments
 (0)