1212from pydvl .parallel .futures import init_executor
1313from pydvl .utils .types import Seed
1414
15+ from ..conftest import num_workers
1516
16- def test_effective_n_jobs (parallel_config , num_workers ):
17+
18+ def test_effective_n_jobs (parallel_config ):
1719 parallel_backend = init_parallel_backend (parallel_config )
1820 assert parallel_backend .effective_n_jobs (1 ) == 1
19- assert parallel_backend .effective_n_jobs (4 ) == min (4 , num_workers )
21+ assert parallel_backend .effective_n_jobs (4 ) == min (4 , num_workers () )
2022 if parallel_config .address is None :
21- assert parallel_backend .effective_n_jobs (- 1 ) == num_workers
23+ assert parallel_backend .effective_n_jobs (- 1 ) == num_workers ()
2224 else :
23- assert parallel_backend .effective_n_jobs (- 1 ) == num_workers
25+ assert parallel_backend .effective_n_jobs (- 1 ) == num_workers ()
2426
2527 for n_jobs in [- 1 , 1 , 2 ]:
2628 assert parallel_backend .effective_n_jobs (n_jobs ) == effective_n_jobs (
@@ -166,7 +168,7 @@ def test_map_reduce_seeding(parallel_config, seed_1, seed_2, op):
166168 assert op (result_1 , result_2 )
167169
168170
169- def test_wrap_function (parallel_config , num_workers ):
171+ def test_wrap_function (parallel_config ):
170172 if parallel_config .backend != "ray" :
171173 pytest .skip ("Only makes sense for ray" )
172174
@@ -188,8 +190,8 @@ def get_pid():
188190 return os .getpid ()
189191
190192 wrapped_func = parallel_backend .wrap (get_pid , num_cpus = 1 )
191- pids = parallel_backend .get ([wrapped_func () for _ in range (num_workers )])
192- assert len (set (pids )) == num_workers
193+ pids = parallel_backend .get ([wrapped_func () for _ in range (num_workers () )])
194+ assert len (set (pids )) == num_workers ()
193195
194196
195197def test_futures_executor_submit (parallel_config ):
@@ -205,7 +207,7 @@ def test_futures_executor_map(parallel_config):
205207 assert results == [1 , 2 , 3 ]
206208
207209
208- def test_futures_executor_map_with_max_workers (parallel_config , num_workers ):
210+ def test_futures_executor_map_with_max_workers (parallel_config ):
209211 if parallel_config .backend != "ray" :
210212 pytest .skip ("Currently this test only works with Ray" )
211213
@@ -215,12 +217,12 @@ def func(_):
215217
216218 start_time = time .monotonic ()
217219 with init_executor (config = parallel_config ) as executor :
218- assert executor ._max_workers == num_workers
220+ assert executor ._max_workers == num_workers ()
219221 list (executor .map (func , range (3 )))
220222 end_time = time .monotonic ()
221223 total_time = end_time - start_time
222- # We expect the time difference to be > 3 / num_workers, but has to be at least 1
223- assert total_time > max (1.0 , 3 / num_workers )
224+ # We expect the time difference to be > 3 / num_workers() , but has to be at least 1
225+ assert total_time > max (1.0 , 3 / num_workers () )
224226
225227
226228def test_future_cancellation (parallel_config ):
0 commit comments