Skip to content

Commit ce77425

Browse files
committed
update
1 parent 47f4ed0 commit ce77425

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

tests/conftest.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,38 @@
1+
import sys
2+
13
import pytest
24

35

46
def assert_raises_with_message(func, expected_msg, *args, **kwargs):
57
with pytest.raises((AssertionError, ValueError, Exception)) as exc_info:
68
func(*args, **kwargs)
79
assert expected_msg in str(exc_info.value)
10+
11+
12+
@pytest.fixture(scope="session")
13+
def dask_client():
14+
"""Create a Dask LocalCluster with restricted resources for distributed tests.
15+
16+
This prevents resource contention and hanging issues in CI environments by:
17+
- Limiting to 1 worker with 2 threads (minimal overhead)
18+
- Setting a 4GB memory limit per worker
19+
- Disabling the dashboard to save resources
20+
- Using processes=False to avoid Nanny issues
21+
"""
22+
# Skip creating client on platforms where distributed tests don't run
23+
if sys.platform == "win32" or sys.version_info <= (3, 9):
24+
yield None
25+
return
26+
27+
from dask.distributed import Client, LocalCluster
28+
29+
with LocalCluster(
30+
n_workers=1,
31+
threads_per_worker=2,
32+
processes=False, # Use threads instead of processes to avoid Nanny issues
33+
memory_limit="4GB",
34+
dashboard_address=None, # Disable dashboard to save resources
35+
silence_logs="error", # Reduce log noise
36+
) as cluster:
37+
with Client(cluster) as client:
38+
yield client

tests/test_evaluation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def test_datasets_evaluate(setup_series, setup_models, setup_metrics):
276276

277277
@pytest.mark.skipif(sys.platform == "win32", reason="Distributed tests are not supported on Windows")
278278
@pytest.mark.skipif(sys.version_info <= (3, 9), reason="Distributed tests are not supported on Python < 3.10")
279-
def test_distributed_evaluate(setup_series):
279+
def test_distributed_evaluate(setup_series, dask_client):
280280
level = [80, 95]
281281
spark = SparkSession.builder.getOrCreate()
282282
spark.sparkContext.setLogLevel("FATAL")

0 commit comments

Comments
 (0)