Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/lightning/fabric/utilities/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose:
Args:
seed: the integer value seed for global random state in Lightning.
If ``None``, it will read the seed from ``PL_GLOBAL_SEED`` env variable. If ``None`` and the
``PL_GLOBAL_SEED`` env variable is not set, then the seed defaults to 0.
``PL_GLOBAL_SEED`` env variable is not set, then the seed defaults to 0. If the seed is provided
as an integer but is not in bounds, a ValueError is raised.
workers: if set to ``True``, will properly configure all dataloaders passed to the
Trainer with a ``worker_init_fn``. If the user already provides such a function
for their dataloaders, setting this argument will have no influence. See also:
Expand All @@ -50,8 +51,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose:
seed = int(seed)

if not (min_seed_value <= seed <= max_seed_value):
rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
seed = 0
raise ValueError(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")

if verbose:
log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank()))
Expand Down
7 changes: 3 additions & 4 deletions tests/tests_fabric/utilities/test_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,9 @@ def test_invalid_seed():
@mock.patch.dict(os.environ, {}, clear=True)
@pytest.mark.parametrize("seed", [10e9, -10e9])
def test_out_of_bounds_seed(seed):
"""Ensure that we still fix the seed even if an out-of-bounds seed is given."""
with pytest.warns(UserWarning, match="is not in bounds"):
actual = seed_everything(seed)
assert actual == 0
"""Ensure that a ValueError is raised if an out-of-bounds seed is given."""
with pytest.raises(ValueError, match="is not in bounds"):
seed_everything(seed)


def test_reset_seed_no_op():
Expand Down
Loading