Skip to content

Commit 55e06af

Browse files
Define _joblib_backend_name attribute for parallel backend classes and use it in MapReduceJob
1 parent 6a7313c commit 55e06af

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

src/pydvl/parallel/backends/joblib.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ class JoblibParallelBackend(ParallelBackend, backend_name="joblib"):
3030
```
3131
"""
3232

33+
_joblib_backend_name: str = "loky"
34+
"""Name of the backend to use for joblib inside [MapReduceJob][pydvl.parallel.mapreduce.MapReduceJob]."""
35+
3336
@deprecated(
3437
target=True,
3538
args_mapping={"config": None},

src/pydvl/parallel/backends/ray.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ class RayParallelBackend(ParallelBackend, backend_name="ray"):
3030
```
3131
"""
3232

33+
_joblib_backend_name: str = "ray"
34+
"""Name of the backend to use for joblib inside [MapReduceJob][pydvl.parallel.mapreduce.MapReduceJob]."""
35+
3336
@deprecated(
3437
target=True,
3538
args_mapping={"config": None},

src/pydvl/parallel/map_reduce.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
This interface might be deprecated or changed in a future release before 1.0
77
88
"""
9+
import warnings
910
from functools import reduce
1011
from itertools import accumulate, repeat
1112
from typing import Any, Collection, Dict, Generic, List, Optional, TypeVar, Union
@@ -18,7 +19,6 @@
1819
from ..utils.functional import maybe_add_argument
1920
from ..utils.types import MapFunction, ReduceFunction, Seed, ensure_seed_sequence
2021
from .backend import ParallelBackend, _maybe_init_parallel_backend
21-
from .backends import JoblibParallelBackend
2222
from .config import ParallelConfig
2323

2424
__all__ = ["MapReduceJob"]
@@ -109,14 +109,6 @@ def __init__(
109109
):
110110
parallel_backend = _maybe_init_parallel_backend(parallel_backend, config)
111111

112-
if not isinstance(parallel_backend, JoblibParallelBackend):
113-
raise ValueError(
114-
f"Unexpected parallel backend {parallel_backend.__class__.__name__}. "
115-
"MapReduceJob only supports the use of JoblibParallelBackend "
116-
"with passing the specific"
117-
"joblib backend name using `joblib.parallel_config`. "
118-
)
119-
120112
self.parallel_backend = parallel_backend
121113

122114
self.timeout = timeout
@@ -149,7 +141,19 @@ def __call__(
149141
"""
150142
seed_seq = ensure_seed_sequence(seed)
151143

152-
with Parallel(prefer="processes") as parallel:
144+
if hasattr(self.parallel_backend, "_joblib_backend_name"):
145+
backend = getattr(self.parallel_backend, "_joblib_backend_name")
146+
else:
147+
warnings.warn(
148+
"Parallel backend "
149+
f"{self.parallel_backend.__class__.__name__}. "
150+
"should have a `_joblib_backend_name` attribute in order to work "
151+
"property with MapReduceJob. "
152+
"Defaulting to joblib loky backend"
153+
)
154+
backend = "loky"
155+
156+
with Parallel(backend=backend, prefer="processes") as parallel:
153157
chunks = self._chunkify(self.inputs_, n_chunks=self.n_jobs)
154158
map_results: List[R] = parallel(
155159
delayed(self._map_func)(

0 commit comments

Comments
 (0)