Skip to content

Commit ef6ee61

Browse files
authored
Update seed.py
1 parent de1a493 commit ef6ee61

File tree

1 file changed

+5
-2
lines changed
  • src/lightning/fabric/utilities

1 file changed

+5
-2
lines changed

src/lightning/fabric/utilities/seed.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
9595
env_seed = os.environ.get("PL_GLOBAL_SEED", None)
9696
if env_seed is None:
9797
env_seed = "0"
98-
rank_zero_warn(f"No seed found, seed set to {env_seed}")
98+
rank_zero_warn(f"No seed found, worker seed set to {env_seed}")
9999
process_seed = int(env_seed)
100100
# back out the base seed so we can use all the bits
101101
base_seed = process_seed - worker_id
@@ -108,7 +108,10 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
108108
if _NUMPY_AVAILABLE:
109109
import numpy as np
110110

111-
np.random.seed(seed_sequence[3] & 0xFFFFFFFF) # numpy takes 32-bit seed only
111+
ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
112+
np_rng_seed = ss.generate_state(4)
113+
114+
np.random.seed(np_rng_seed)
112115

113116

114117
def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> List[int]:

0 commit comments

Comments
 (0)