Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

-

### Changed

- Raise ValueError when seed is `out-of-bounds` or `cannot be cast to int` ([#21029](https://github.com/Lightning-AI/pytorch-lightning/pull/21029))


---

Expand Down
9 changes: 4 additions & 5 deletions src/lightning/fabric/utilities/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose:
Args:
seed: the integer value seed for global random state in Lightning.
If ``None``, it will read the seed from ``PL_GLOBAL_SEED`` env variable. If ``None`` and the
``PL_GLOBAL_SEED`` env variable is not set, then the seed defaults to 0.
``PL_GLOBAL_SEED`` env variable is not set, then the seed defaults to 0. If seed is
not in bounds or cannot be cast to int, a ValueError is raised.
workers: if set to ``True``, will properly configure all dataloaders passed to the
Trainer with a ``worker_init_fn``. If the user already provides such a function
for their dataloaders, setting this argument will have no influence. See also:
Expand All @@ -44,14 +45,12 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose:
try:
seed = int(env_seed)
except ValueError:
seed = 0
rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}")
raise ValueError(f"Invalid seed specified via PL_GLOBAL_SEED: {repr(env_seed)}")
elif not isinstance(seed, int):
seed = int(seed)

if not (min_seed_value <= seed <= max_seed_value):
rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
seed = 0
raise ValueError(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")

if verbose:
log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank()))
Expand Down
26 changes: 18 additions & 8 deletions tests/tests_fabric/utilities/test_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,29 @@ def test_correct_seed_with_environment_variable():

@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True)
def test_invalid_seed():
"""Ensure that we still fix the seed even if an invalid seed is given."""
with pytest.warns(UserWarning, match="Invalid seed found"):
seed = seed_everything()
assert seed == 0
"""Ensure that a ValueError is raised if an invalid seed is given."""
with pytest.raises(ValueError, match="Invalid seed specified"):
seed_everything()


@mock.patch.dict(os.environ, {}, clear=True)
@pytest.mark.parametrize("seed", [10e9, -10e9])
def test_out_of_bounds_seed(seed):
"""Ensure that we still fix the seed even if an out-of-bounds seed is given."""
with pytest.warns(UserWarning, match="is not in bounds"):
actual = seed_everything(seed)
assert actual == 0
"""Ensure that a ValueError is raised if an out-of-bounds seed is given."""
with pytest.raises(ValueError, match="is not in bounds"):
seed_everything(seed)


def test_seed_everything_accepts_valid_seed_argument():
"""Ensure that seed_everything returns the provided valid seed."""
seed_value = 45
assert seed_everything(seed_value) == seed_value


@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "17"}, clear=True)
def test_seed_everything_accepts_valid_seed_from_env():
"""Ensure that seed_everything uses the valid seed from the PL_GLOBAL_SEED environment variable."""
assert seed_everything() == 17


def test_reset_seed_no_op():
Expand Down
Loading