Skip to content

Commit 8d847fd

Browse files
jonathankingpre-commit-ci[bot]deependujha
authored
fix: raise ValueError when provided seed is out-of-bounds (#21029)
* fix: crash when provided seed is out-of-bounds * fix: update handling of invalid seeds via PL_GLOBAL_SEED * fix: update test for invalid env var seed * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: typo * Apply suggestions from code review Co-authored-by: Deependu <[email protected]> * changelog * update --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deependu <[email protected]>
1 parent e752cb5 commit 8d847fd

File tree

3 files changed

+26
-13
lines changed

3 files changed

+26
-13
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1616

1717
-
1818

19+
### Changed
20+
21+
- Raise ValueError when seed is `out-of-bounds` or `cannot be cast to int` ([#21029](https://github.com/Lightning-AI/pytorch-lightning/pull/21029))
22+
1923

2024
---
2125

src/lightning/fabric/utilities/seed.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose:
2727
Args:
2828
seed: the integer value seed for global random state in Lightning.
2929
If ``None``, it will read the seed from ``PL_GLOBAL_SEED`` env variable. If ``None`` and the
30-
``PL_GLOBAL_SEED`` env variable is not set, then the seed defaults to 0.
30+
``PL_GLOBAL_SEED`` env variable is not set, then the seed defaults to 0. If seed is
31+
not in bounds or cannot be cast to int, a ValueError is raised.
3132
workers: if set to ``True``, will properly configure all dataloaders passed to the
3233
Trainer with a ``worker_init_fn``. If the user already provides such a function
3334
for their dataloaders, setting this argument will have no influence. See also:
@@ -44,14 +45,12 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose:
4445
try:
4546
seed = int(env_seed)
4647
except ValueError:
47-
seed = 0
48-
rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}")
48+
raise ValueError(f"Invalid seed specified via PL_GLOBAL_SEED: {repr(env_seed)}")
4949
elif not isinstance(seed, int):
5050
seed = int(seed)
5151

5252
if not (min_seed_value <= seed <= max_seed_value):
53-
rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
54-
seed = 0
53+
raise ValueError(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
5554

5655
if verbose:
5756
log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank()))

tests/tests_fabric/utilities/test_seed.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,29 @@ def test_correct_seed_with_environment_variable():
4747

4848
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True)
4949
def test_invalid_seed():
50-
"""Ensure that we still fix the seed even if an invalid seed is given."""
51-
with pytest.warns(UserWarning, match="Invalid seed found"):
52-
seed = seed_everything()
53-
assert seed == 0
50+
"""Ensure that a ValueError is raised if an invalid seed is given."""
51+
with pytest.raises(ValueError, match="Invalid seed specified"):
52+
seed_everything()
5453

5554

5655
@mock.patch.dict(os.environ, {}, clear=True)
5756
@pytest.mark.parametrize("seed", [10e9, -10e9])
5857
def test_out_of_bounds_seed(seed):
59-
"""Ensure that we still fix the seed even if an out-of-bounds seed is given."""
60-
with pytest.warns(UserWarning, match="is not in bounds"):
61-
actual = seed_everything(seed)
62-
assert actual == 0
58+
"""Ensure that a ValueError is raised if an out-of-bounds seed is given."""
59+
with pytest.raises(ValueError, match="is not in bounds"):
60+
seed_everything(seed)
61+
62+
63+
def test_seed_everything_accepts_valid_seed_argument():
64+
"""Ensure that seed_everything returns the provided valid seed."""
65+
seed_value = 45
66+
assert seed_everything(seed_value) == seed_value
67+
68+
69+
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "17"}, clear=True)
70+
def test_seed_everything_accepts_valid_seed_from_env():
71+
"""Ensure that seed_everything uses the valid seed from the PL_GLOBAL_SEED environment variable."""
72+
assert seed_everything() == 17
6373

6474

6575
def test_reset_seed_no_op():

0 commit comments

Comments
 (0)