
Commit 3d0e579

Merge pull request #440 from aai-institute/feature/pytest-xdist

Use pytest-xdist (mostly for local testing)

2 parents: 43690b0 + 3aa4651

File tree: 7 files changed (+44, -27 lines)

CHANGELOG.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -2,6 +2,8 @@

 ## Unreleased

+- Using pytest-xdist for faster local tests
+  [PR #440](https://github.com/aai-institute/pyDVL/pull/440)
 - Implementation of Data-OOB by @BastienZim
   [PR #426](https://github.com/aai-institute/pyDVL/pull/426),
   [PR #431](https://github.com/aai-institute/pyDVL/pull/431)
```

requirements-dev.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -19,5 +19,6 @@ pytest-docker==2.0.0
 pytest-mock
 pytest-timeout
 pytest-lazy-fixture
+pytest-xdist>=3.3.1
 wheel
 twine==4.0.2
```

tests/conftest.py

Lines changed: 16 additions & 4 deletions

```diff
@@ -196,7 +196,6 @@ def seed_numpy(seed=42):
     np.random.seed(seed)


-@pytest.fixture(scope="session")
 def num_workers():
     # Run with 2 CPUs inside GitHub actions
     if os.getenv("CI"):
@@ -205,9 +204,22 @@ def num_workers():
     return max(1, min(available_cpus() - 1, 4))


-@pytest.fixture
-def n_jobs(num_workers):
-    return num_workers
+@pytest.fixture(scope="session")
+def n_jobs():
+    return num_workers()
+
+
+def pytest_xdist_auto_num_workers(config) -> Optional[int]:
+    """Return the number of workers to use for pytest-xdist.
+
+    This is used by pytest-xdist to automatically determine the number of
+    workers to use. We want to use all available CPUs, but leave one CPU for
+    the main process.
+    """
+
+    if config.option.numprocesses == "auto":
+        return max(1, (available_cpus() - 1) // num_workers())
+    return None


 ################################################################################
```
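The division in the new hook deserves a note: each xdist test process may itself spawn `num_workers()` parallel jobs (via joblib or Ray), so the hook budgets CPUs for the product of the two. A minimal sketch of the arithmetic, assuming an 8-CPU machine (the stubbed `available_cpus` is for illustration only; the real one comes from pyDVL's utilities):

```python
# Sketch of the hook's CPU budgeting, assuming an 8-CPU machine.
def available_cpus() -> int:
    return 8  # assumed value for illustration


def num_workers() -> int:
    # Mirrors tests/conftest.py: leave one CPU free, cap at 4.
    return max(1, min(available_cpus() - 1, 4))


# With `pytest -n auto`, the hook grants xdist
# max(1, (available_cpus() - 1) // num_workers()) processes, so the
# number of xdist processes times the per-test parallel jobs stays
# within the machine: here 1 xdist process x 4 workers = 4 <= 7 spare CPUs.
print(max(1, (available_cpus() - 1) // num_workers()))  # -> 1
```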

tests/utils/conftest.py

Lines changed: 6 additions & 7 deletions

```diff
@@ -2,17 +2,19 @@

 from pydvl.parallel.config import ParallelConfig

+from ..conftest import num_workers
+

 @pytest.fixture(scope="module", params=["joblib", "ray-local", "ray-external"])
-def parallel_config(request, num_workers):
+def parallel_config(request):
     if request.param == "joblib":
-        yield ParallelConfig(backend="joblib", n_cpus_local=num_workers)
+        yield ParallelConfig(backend="joblib", n_cpus_local=num_workers())
     elif request.param == "ray-local":
         try:
             import ray
         except ImportError:
             pytest.skip("Ray not installed.")
-        yield ParallelConfig(backend="ray", n_cpus_local=num_workers)
+        yield ParallelConfig(backend="ray", n_cpus_local=num_workers())
         ray.shutdown()
     elif request.param == "ray-external":
         try:
@@ -22,10 +24,7 @@ def parallel_config(request, num_workers):
             pytest.skip("Ray not installed.")
         # Starts a head-node for the cluster.
         cluster = Cluster(
-            initialize_head=True,
-            head_node_args={
-                "num_cpus": num_workers,
-            },
+            initialize_head=True, head_node_args={"num_cpus": num_workers()}
         )
         yield ParallelConfig(backend="ray", address=cluster.address)
         ray.shutdown()
```
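Why `num_workers` stopped being a fixture is worth spelling out (this rationale is an editorial reading of the diff, not stated in the PR): pytest refuses to let fixture functions be called directly, and `pytest_xdist_auto_num_workers` is a hook that runs before any fixture exists, so the value must come from a plain importable function. A minimal sketch of the failure the change avoids:

```python
# Minimal sketch (assumed rationale): fixture functions cannot be
# called directly, so a plain function is needed inside hooks.
import pytest


@pytest.fixture(scope="session")
def as_fixture():
    return 4


def num_workers() -> int:  # plain function: importable and callable anywhere
    return 4


def pytest_xdist_auto_num_workers(config):
    # as_fixture() here would raise "Fixtures are not meant to be
    # called directly"; the plain function works fine.
    return num_workers()
```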

tests/utils/test_parallel.py

Lines changed: 13 additions & 11 deletions

```diff
@@ -12,15 +12,17 @@
 from pydvl.parallel.futures import init_executor
 from pydvl.utils.types import Seed

+from ..conftest import num_workers

-def test_effective_n_jobs(parallel_config, num_workers):
+
+def test_effective_n_jobs(parallel_config):
     parallel_backend = init_parallel_backend(parallel_config)
     assert parallel_backend.effective_n_jobs(1) == 1
-    assert parallel_backend.effective_n_jobs(4) == min(4, num_workers)
+    assert parallel_backend.effective_n_jobs(4) == min(4, num_workers())
     if parallel_config.address is None:
-        assert parallel_backend.effective_n_jobs(-1) == num_workers
+        assert parallel_backend.effective_n_jobs(-1) == num_workers()
     else:
-        assert parallel_backend.effective_n_jobs(-1) == num_workers
+        assert parallel_backend.effective_n_jobs(-1) == num_workers()

     for n_jobs in [-1, 1, 2]:
         assert parallel_backend.effective_n_jobs(n_jobs) == effective_n_jobs(
@@ -166,7 +168,7 @@ def test_map_reduce_seeding(parallel_config, seed_1, seed_2, op):
     assert op(result_1, result_2)


-def test_wrap_function(parallel_config, num_workers):
+def test_wrap_function(parallel_config):
     if parallel_config.backend != "ray":
         pytest.skip("Only makes sense for ray")

@@ -188,8 +190,8 @@ def get_pid():
         return os.getpid()

     wrapped_func = parallel_backend.wrap(get_pid, num_cpus=1)
-    pids = parallel_backend.get([wrapped_func() for _ in range(num_workers)])
-    assert len(set(pids)) == num_workers
+    pids = parallel_backend.get([wrapped_func() for _ in range(num_workers())])
+    assert len(set(pids)) == num_workers()


 def test_futures_executor_submit(parallel_config):
@@ -205,7 +207,7 @@ def test_futures_executor_map(parallel_config):
     assert results == [1, 2, 3]


-def test_futures_executor_map_with_max_workers(parallel_config, num_workers):
+def test_futures_executor_map_with_max_workers(parallel_config):
     if parallel_config.backend != "ray":
         pytest.skip("Currently this test only works with Ray")

@@ -215,12 +217,12 @@ def func(_):

     start_time = time.monotonic()
     with init_executor(config=parallel_config) as executor:
-        assert executor._max_workers == num_workers
+        assert executor._max_workers == num_workers()
         list(executor.map(func, range(3)))
     end_time = time.monotonic()
     total_time = end_time - start_time
-    # We expect the time difference to be > 3 / num_workers, but has to be at least 1
-    assert total_time > max(1.0, 3 / num_workers)
+    # We expect the time difference to be > 3 / num_workers(), but has to be at least 1
+    assert total_time > max(1.0, 3 / num_workers())


 def test_future_cancellation(parallel_config):
```
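The timing bound in `test_futures_executor_map_with_max_workers` is simple pipelining arithmetic; a standalone sketch with an assumed worker count:

```python
# Sketch of the assertion's lower bound (W = 4 is an assumed example).
# Three 1-second tasks on W workers take at least 3 / W seconds of wall
# time, and any single batch already takes 1 second, hence the floor.
W = 4
lower_bound = max(1.0, 3 / W)
print(lower_bound)  # 1.0: with W >= 3, the one-second batch dominates
```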

tests/value/conftest.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -10,6 +10,7 @@
 from pydvl.utils.status import Status
 from pydvl.value import ValuationResult

+from ..conftest import num_workers
 from . import polynomial


@@ -122,5 +123,5 @@ def linear_shapley(linear_dataset, scorer, n_jobs):


 @pytest.fixture(scope="module")
-def parallel_config(num_workers):
-    yield ParallelConfig(backend="joblib", n_cpus_local=num_workers, wait_timeout=0.1)
+def parallel_config():
+    yield ParallelConfig(backend="joblib", n_cpus_local=num_workers(), wait_timeout=0.1)
```

tox.ini

Lines changed: 3 additions & 3 deletions

```diff
@@ -12,12 +12,12 @@ setenv =
 [testenv:base]
 description = Tests base modules
 commands =
-    pytest --cov "{envsitepackagesdir}/pydvl" -m "not torch" {posargs}
+    pytest -n auto --cov "{envsitepackagesdir}/pydvl" -m "not torch" {posargs}

 [testenv:torch]
 description = Tests modules that rely on pytorch
 commands =
-    pytest --cov "{envsitepackagesdir}/pydvl" -m torch {posargs}
+    pytest -n auto --cov "{envsitepackagesdir}/pydvl" -m torch {posargs}
 extras =
     influence

@@ -26,7 +26,7 @@ description = Tests notebooks
 setenv =
     PYTHONPATH={toxinidir}/notebooks
 commands =
-    pytest notebooks/ --cov "{envsitepackagesdir}/pydvl"
+    pytest -n auto notebooks/ --cov "{envsitepackagesdir}/pydvl" {posargs}
 deps =
     {[testenv]deps}
     jupyter==1.0.0
```
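To confirm that `-n auto` is actually distributing tests when these environments run, pytest-xdist exposes the worker id through the `PYTEST_XDIST_WORKER` environment variable; a quick check one could drop into a test (a sketch, not part of this PR):

```python
# Hedged sketch: verify xdist is active. pytest-xdist sets
# PYTEST_XDIST_WORKER to "gw0", "gw1", ... in worker processes;
# it is unset in a plain, non-distributed pytest run.
import os


def test_xdist_worker_id():
    worker = os.environ.get("PYTEST_XDIST_WORKER", "main")
    assert worker == "main" or worker.startswith("gw")
```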
