diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 4ce80bb6451e3..18537ca15e2fc 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -16,6 +16,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - +### Changed + +- Raise `ValueError` when the seed is out of bounds or cannot be cast to int ([#21029](https://github.com/Lightning-AI/pytorch-lightning/pull/21029)) + --- diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py index f9c0ddeb86cf0..534e5e3db653e 100644 --- a/src/lightning/fabric/utilities/seed.py +++ b/src/lightning/fabric/utilities/seed.py @@ -27,7 +27,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose: Args: seed: the integer value seed for global random state in Lightning. If ``None``, it will read the seed from ``PL_GLOBAL_SEED`` env variable. If ``None`` and the - ``PL_GLOBAL_SEED`` env variable is not set, then the seed defaults to 0. + ``PL_GLOBAL_SEED`` env variable is not set, then the seed defaults to 0. If seed is + not in bounds or cannot be cast to int, a ValueError is raised. workers: if set to ``True``, will properly configure all dataloaders passed to the Trainer with a ``worker_init_fn``. If the user already provides such a function for their dataloaders, setting this argument will have no influence. 
See also: @@ -44,14 +45,12 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose: try: seed = int(env_seed) except ValueError: - seed = 0 - rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}") + raise ValueError(f"Invalid seed specified via PL_GLOBAL_SEED: {repr(env_seed)}") elif not isinstance(seed, int): seed = int(seed) if not (min_seed_value <= seed <= max_seed_value): - rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") - seed = 0 + raise ValueError(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") if verbose: log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank())) diff --git a/tests/tests_fabric/utilities/test_seed.py b/tests/tests_fabric/utilities/test_seed.py index 4a948a5f98736..2700213747f9a 100644 --- a/tests/tests_fabric/utilities/test_seed.py +++ b/tests/tests_fabric/utilities/test_seed.py @@ -47,19 +47,29 @@ def test_correct_seed_with_environment_variable(): @mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True) def test_invalid_seed(): - """Ensure that we still fix the seed even if an invalid seed is given.""" - with pytest.warns(UserWarning, match="Invalid seed found"): - seed = seed_everything() - assert seed == 0 + """Ensure that a ValueError is raised if an invalid seed is given.""" + with pytest.raises(ValueError, match="Invalid seed specified"): + seed_everything() @mock.patch.dict(os.environ, {}, clear=True) @pytest.mark.parametrize("seed", [10e9, -10e9]) def test_out_of_bounds_seed(seed): - """Ensure that we still fix the seed even if an out-of-bounds seed is given.""" - with pytest.warns(UserWarning, match="is not in bounds"): - actual = seed_everything(seed) - assert actual == 0 + """Ensure that a ValueError is raised if an out-of-bounds seed is given.""" + with pytest.raises(ValueError, match="is not in bounds"): + seed_everything(seed) + + +def 
test_seed_everything_accepts_valid_seed_argument(): + """Ensure that seed_everything returns the provided valid seed.""" + seed_value = 45 + assert seed_everything(seed_value) == seed_value + + +@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "17"}, clear=True) +def test_seed_everything_accepts_valid_seed_from_env(): + """Ensure that seed_everything uses the valid seed from the PL_GLOBAL_SEED environment variable.""" + assert seed_everything() == 17 def test_reset_seed_no_op():