|
6 | 6 | This interface might be deprecated or changed in a future release before 1.0 |
7 | 7 |
|
8 | 8 | """ |
| 9 | +import warnings |
9 | 10 | from functools import reduce |
10 | 11 | from itertools import accumulate, repeat |
11 | 12 | from typing import Any, Collection, Dict, Generic, List, Optional, TypeVar, Union |
|
18 | 19 | from ..utils.functional import maybe_add_argument |
19 | 20 | from ..utils.types import MapFunction, ReduceFunction, Seed, ensure_seed_sequence |
20 | 21 | from .backend import ParallelBackend, _maybe_init_parallel_backend |
21 | | -from .backends import JoblibParallelBackend |
22 | 22 | from .config import ParallelConfig |
23 | 23 |
|
24 | 24 | __all__ = ["MapReduceJob"] |
@@ -109,14 +109,6 @@ def __init__( |
109 | 109 | ): |
110 | 110 | parallel_backend = _maybe_init_parallel_backend(parallel_backend, config) |
111 | 111 |
|
112 | | - if not isinstance(parallel_backend, JoblibParallelBackend): |
113 | | - raise ValueError( |
114 | | - f"Unexpected parallel backend {parallel_backend.__class__.__name__}. " |
115 | | - "MapReduceJob only supports the use of JoblibParallelBackend " |
116 | | - "with passing the specific" |
117 | | - "joblib backend name using `joblib.parallel_config`. " |
118 | | - ) |
119 | | - |
120 | 112 | self.parallel_backend = parallel_backend |
121 | 113 |
|
122 | 114 | self.timeout = timeout |
@@ -149,7 +141,19 @@ def __call__( |
149 | 141 | """ |
150 | 142 | seed_seq = ensure_seed_sequence(seed) |
151 | 143 |
|
152 | | - with Parallel(prefer="processes") as parallel: |
| 144 | + if hasattr(self.parallel_backend, "_joblib_backend_name"): |
| 145 | + backend = getattr(self.parallel_backend, "_joblib_backend_name") |
| 146 | + else: |
| 147 | + warnings.warn( |
| 148 | + "Parallel backend " |
| 149 | +                f"{self.parallel_backend.__class__.__name__} "
| 150 | +                "should have a `_joblib_backend_name` attribute in order to work "
| 151 | +                "properly with MapReduceJob. "
| 152 | + "Defaulting to joblib loky backend" |
| 153 | + ) |
| 154 | + backend = "loky" |
| 155 | + |
| 156 | + with Parallel(backend=backend, prefer="processes") as parallel: |
153 | 157 | chunks = self._chunkify(self.inputs_, n_chunks=self.n_jobs) |
154 | 158 | map_results: List[R] = parallel( |
155 | 159 | delayed(self._map_func)( |
|
0 commit comments