Skip to content

Commit 2ba97e7

Browse files
committed
Remove unnecessary function and merge tests
1 parent b827439 commit 2ba97e7

File tree

1 file changed

+16
-29
lines changed

1 file changed

+16
-29
lines changed

tests/utils/test_parallel.py

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -147,38 +147,29 @@ def reduce_func(x, y):
147147
assert result == 150
148148

149149

150-
def test_map_reduce_reproducible(parallel_config, seed):
151-
"""
152-
Test that the same result is obtained when using the same seed. And that different
153-
results are obtained when using different seeds.
154-
"""
155-
156-
map_reduce_job = MapReduceJob(
157-
None,
158-
map_func=_sum_of_random_integers,
159-
reduce_func=_mean_func,
160-
config=parallel_config,
161-
)
162-
result_1 = map_reduce_job(seed=seed)
163-
result_2 = map_reduce_job(seed=seed)
164-
assert result_1 == result_2
165-
166-
167-
def test_map_reduce_stochastic(parallel_config, seed, seed_alt):
168-
"""
169-
Test that the same result is obtained when using the same seed. And that different
170-
results are obtained when using different seeds.
150+
@pytest.mark.parametrize(
151+
"seed_1, seed_2, op",
152+
[
153+
(None, None, operator.ne),
154+
(None, 42, operator.ne),
155+
(42, None, operator.ne),
156+
(42, 42, operator.eq),
157+
],
158+
)
159+
def test_map_reduce_seeding(parallel_config, seed_1, seed_2, op):
160+
"""Test that the same result is obtained when using the same seed. And that
161+
different results are obtained when using different seeds.
171162
"""
172163

173164
map_reduce_job = MapReduceJob(
174165
None,
175166
map_func=_sum_of_random_integers,
176-
reduce_func=_mean_func,
167+
reduce_func=np.mean,
177168
config=parallel_config,
178169
)
179-
result_1 = map_reduce_job(seed=seed)
180-
result_2 = map_reduce_job(seed=seed_alt)
181-
assert result_1 != result_2
170+
result_1 = map_reduce_job(seed=seed_1)
171+
result_2 = map_reduce_job(seed=seed_2)
172+
assert op(result_1, result_2)
182173

183174

184175
def test_wrap_function(parallel_config, num_workers):
@@ -273,7 +264,3 @@ def _sum_of_random_integers(x: None, seed: Optional[Seed] = None):
273264
rng = np.random.default_rng(seed)
274265
values = rng.integers(0, rng.integers(10, 100), 10)
275266
return np.sum(values)
276-
277-
278-
def _mean_func(means):
279-
return np.mean(means)

0 commit comments

Comments
 (0)