Skip to content

Commit e31fbe9

Browse files
committed
add testing
1 parent 45c8463 commit e31fbe9

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

tests/tests_fabric/utilities/test_seed.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from lightning.fabric.utilities.seed import (
1212
_collect_rng_states,
13+
_generate_seed_sequence,
1314
_set_rng_states,
1415
pl_worker_init_function,
1516
reset_seed,
@@ -153,3 +154,23 @@ def test_pl_worker_init_function(base_seed, num_workers, num_ranks):
153154
assert len(stdlib_rands) == num_ranks * num_workers
154155
assert len(numpy_rands) == num_ranks * num_workers
155156
assert len(torch_rands | stdlib_rands | numpy_rands) == 3 * num_workers * num_ranks
157+
158+
159+
def test_generate_seed_sequence_no_collision():
160+
"""Test that _generate_seed_sequence produces unique seeds for different base seeds."""
161+
base_seeds = [0, 1, 42, 123, 999, 12345]
162+
generated_seeds = []
163+
random_outputs = []
164+
165+
for base_seed in base_seeds:
166+
seed_everything(base_seed)
167+
process_seed = torch.initial_seed()
168+
generated_seed = _generate_seed_sequence(process_seed, worker_id=0, global_rank=0, count=1)[0]
169+
generated_seeds.append(generated_seed)
170+
torch.manual_seed(generated_seed)
171+
random_outputs.append(tuple(torch.randn(10).tolist()))
172+
173+
assert len(set(generated_seeds)) == len(generated_seeds), (
174+
"Generated seeds should be unique for different base seeds"
175+
)
176+
assert len(set(random_outputs)) == len(random_outputs), "Random outputs should be unique for different base seeds"

0 commit comments

Comments
 (0)