@@ -147,38 +147,29 @@ def reduce_func(x, y):
147147 assert result == 150
148148
149149
150- def test_map_reduce_reproducible (parallel_config , seed ):
151- """
152- Test that the same result is obtained when using the same seed. And that different
153- results are obtained when using different seeds.
154- """
155-
156- map_reduce_job = MapReduceJob (
157- None ,
158- map_func = _sum_of_random_integers ,
159- reduce_func = _mean_func ,
160- config = parallel_config ,
161- )
162- result_1 = map_reduce_job (seed = seed )
163- result_2 = map_reduce_job (seed = seed )
164- assert result_1 == result_2
165-
166-
167- def test_map_reduce_stochastic (parallel_config , seed , seed_alt ):
168- """
169- Test that the same result is obtained when using the same seed. And that different
170- results are obtained when using different seeds.
150+ @pytest .mark .parametrize (
151+ "seed_1, seed_2, op" ,
152+ [
153+ (None , None , operator .ne ),
154+ (None , 42 , operator .ne ),
155+ (42 , None , operator .ne ),
156+ (42 , 42 , operator .eq ),
157+ ],
158+ )
159+ def test_map_reduce_seeding (parallel_config , seed_1 , seed_2 , op ):
160+ """Test that the same result is obtained when using the same seed. And that
161+ different results are obtained when using different seeds.
171162 """
172163
173164 map_reduce_job = MapReduceJob (
174165 None ,
175166 map_func = _sum_of_random_integers ,
176- reduce_func = _mean_func ,
167+ reduce_func = np . mean ,
177168 config = parallel_config ,
178169 )
179- result_1 = map_reduce_job (seed = seed )
180- result_2 = map_reduce_job (seed = seed_alt )
181- assert result_1 != result_2
170+ result_1 = map_reduce_job (seed = seed_1 )
171+ result_2 = map_reduce_job (seed = seed_2 )
172+ assert op ( result_1 , result_2 )
182173
183174
184175def test_wrap_function (parallel_config , num_workers ):
@@ -273,7 +264,3 @@ def _sum_of_random_integers(x: None, seed: Optional[Seed] = None):
273264 rng = np .random .default_rng (seed )
274265 values = rng .integers (0 , rng .integers (10 , 100 ), 10 )
275266 return np .sum (values )
276-
277-
278- def _mean_func (means ):
279- return np .mean (means )
0 commit comments