Skip to content

Commit 3a299b8

Browse files
authored
Merge pull request #14 from grok-ai/feature/tests
Return the seed resolved from the seed_index
2 parents a56bbcc + 989c6b7 commit 3a299b8

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/nn_core/common/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,15 @@ def enforce_tags(tags: Optional[List[str]]) -> List[str]:
9595
return tags
9696

9797

98-
def seed_index_everything(train_cfg: DictConfig, sampling_seed: int = 42) -> None:
98+
def seed_index_everything(train_cfg: DictConfig, sampling_seed: int = 42) -> Optional[int]:
9999
if "seed_index" in train_cfg and train_cfg.seed_index is not None:
100100
seed_index = train_cfg.seed_index
101101
np.random.seed(sampling_seed)
102102
seeds = np.random.randint(np.iinfo(np.int32).max, size=max(42, seed_index + 1))
103103
seed = seeds[seed_index]
104104
seed_everything(seed)
105105
pylogger.info(f"Setting seed {seed} from seeds[{seed_index}]")
106+
return seed
106107
else:
107108
pylogger.warning("The seed has not been set! The reproducibility is not guaranteed.")
109+
return None

0 commit comments

Comments
 (0)