Skip to content

Commit 49772c3

Browse files
authored
Fix _generate_seed_sequence sampling (#21399)
1 parent 79ffe50 commit 49772c3

File tree

3 files changed

+36
-2
lines changed

3 files changed

+36
-2
lines changed

src/lightning/fabric/utilities/seed.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,19 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
111111

112112

113113
def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> list[int]:
114-
"""Generates a sequence of seeds from a base seed, worker id and rank using the linear congruential generator (LCG)
115-
algorithm."""
114+
"""Generates a sequence of seeds from a base seed, worker id and rank using hash-based mixing followed by the
115+
linear congruential generator (LCG) algorithm."""
116116
# Combine base seed, worker id and rank into a unique 64-bit number
117117
combined_seed = (base_seed << 32) | (worker_id << 16) | global_rank
118+
119+
# Apply hash-based mixing (MurmurHash3 finalizer) to distribute bits uniformly
120+
# This ensures that small base seeds don't result in zeros in lower bits
121+
combined_seed ^= combined_seed >> 33
122+
combined_seed = (combined_seed * 0xFF51AFD7ED558CCD) & ((1 << 64) - 1)
123+
combined_seed ^= combined_seed >> 33
124+
combined_seed = (combined_seed * 0xC4CEB9FE1A85EC53) & ((1 << 64) - 1)
125+
combined_seed ^= combined_seed >> 33
126+
118127
seeds = []
119128
for _ in range(count):
120129
# x_(n+1) = (a * x_n + c) mod m. With c=1, m=2^64 and a is D. Knuth's constant

src/lightning/pytorch/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2020

2121
-
2222

23+
### Fixed
24+
25+
- Fix `_generate_seed_sequence` function not producing unique seeds ([#21399](https://github.com/Lightning-AI/pytorch-lightning/pull/21399))
26+
2327

2428
## [2.6.0] - 2025-11-28
2529

tests/tests_fabric/utilities/test_seed.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from lightning.fabric.utilities.seed import (
1212
_collect_rng_states,
13+
_generate_seed_sequence,
1314
_set_rng_states,
1415
pl_worker_init_function,
1516
reset_seed,
@@ -153,3 +154,23 @@ def test_pl_worker_init_function(base_seed, num_workers, num_ranks):
153154
assert len(stdlib_rands) == num_ranks * num_workers
154155
assert len(numpy_rands) == num_ranks * num_workers
155156
assert len(torch_rands | stdlib_rands | numpy_rands) == 3 * num_workers * num_ranks
157+
158+
159+
def test_generate_seed_sequence_no_collision():
160+
"""Test that _generate_seed_sequence produces unique seeds for different base seeds."""
161+
base_seeds = [0, 1, 42, 123, 999, 12345]
162+
generated_seeds = []
163+
random_outputs = []
164+
165+
for base_seed in base_seeds:
166+
seed_everything(base_seed)
167+
process_seed = torch.initial_seed()
168+
generated_seed = _generate_seed_sequence(process_seed, worker_id=0, global_rank=0, count=1)[0]
169+
generated_seeds.append(generated_seed)
170+
torch.manual_seed(generated_seed)
171+
random_outputs.append(tuple(torch.randn(10).tolist()))
172+
173+
assert len(set(generated_seeds)) == len(generated_seeds), (
174+
"Generated seeds should be unique for different base seeds"
175+
)
176+
assert len(set(random_outputs)) == len(random_outputs), "Random outputs should be unique for different base seeds"

0 commit comments

Comments
 (0)