@@ -47,19 +47,29 @@ def test_correct_seed_with_environment_variable():
4747
4848@mock .patch .dict (os .environ , {"PL_GLOBAL_SEED" : "invalid" }, clear = True )
4949def test_invalid_seed ():
50- """Ensure that we still fix the seed even if an invalid seed is given."""
51- with pytest .warns (UserWarning , match = "Invalid seed found" ):
52- seed = seed_everything ()
53- assert seed == 0
50+ """Ensure that a ValueError is raised if an invalid seed is given."""
51+ with pytest .raises (ValueError , match = "Invalid seed specified" ):
52+ seed_everything ()
5453
5554
5655@mock .patch .dict (os .environ , {}, clear = True )
5756@pytest .mark .parametrize ("seed" , [10e9 , - 10e9 ])
5857def test_out_of_bounds_seed (seed ):
59- """Ensure that we still fix the seed even if an out-of-bounds seed is given."""
60- with pytest .warns (UserWarning , match = "is not in bounds" ):
61- actual = seed_everything (seed )
62- assert actual == 0
58+ """Ensure that a ValueError is raised if an out-of-bounds seed is given."""
59+ with pytest .raises (ValueError , match = "is not in bounds" ):
60+ seed_everything (seed )
61+
62+
63+ def test_seed_everything_accepts_valid_seed_argument ():
64+ """Ensure that seed_everything returns the provided valid seed."""
65+ seed_value = 45
66+ assert seed_everything (seed_value ) == seed_value
67+
68+
69+ @mock .patch .dict (os .environ , {"PL_GLOBAL_SEED" : "17" }, clear = True )
70+ def test_seed_everything_accepts_valid_seed_from_env ():
71+ """Ensure that seed_everything uses the valid seed from the PL_GLOBAL_SEED environment variable."""
72+ assert seed_everything () == 17
6373
6474
6575def test_reset_seed_no_op ():
0 commit comments