|
1 | | -import functools |
2 | 1 | import os |
3 | 2 | from abc import ABCMeta, abstractmethod |
4 | 3 | from dataclasses import asdict |
5 | | -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union |
| 4 | +from typing import ( |
| 5 | + Any, |
| 6 | + Callable, |
| 7 | + Dict, |
| 8 | + Iterable, |
| 9 | + List, |
| 10 | + Optional, |
| 11 | + Tuple, |
| 12 | + Type, |
| 13 | + TypeVar, |
| 14 | + Union, |
| 15 | +) |
6 | 16 |
|
7 | 17 | import ray |
8 | 18 | from ray import ObjectRef |
9 | 19 | from ray.remote_function import RemoteFunction |
10 | 20 |
|
11 | 21 | from ..config import ParallelConfig |
12 | 22 |
|
13 | | -__all__ = [ |
14 | | - "init_parallel_backend", |
15 | | -] |
| 23 | +__all__ = ["init_parallel_backend", "effective_n_jobs", "available_cpus"] |
16 | 24 |
|
17 | 25 | T = TypeVar("T") |
18 | 26 |
|
@@ -63,7 +71,7 @@ def put(self, v: Any, *args, **kwargs) -> Any: |
63 | 71 | ... |
64 | 72 |
|
65 | 73 | @abstractmethod |
66 | | - def wrap(self, *args, **kwargs) -> Any: |
| 74 | + def wrap(self, fun: Callable, **kwargs) -> Callable: |
67 | 75 | ... |
68 | 76 |
|
69 | 77 | @abstractmethod |
@@ -104,9 +112,11 @@ def get(self, v: Any, *args, **kwargs): |
104 | 112 | def put(self, v: Any, *args, **kwargs) -> Any: |
105 | 113 | return v |
106 | 114 |
|
107 | | - def wrap(self, *args, **kwargs) -> Any: |
108 | | - assert len(args) == 1 |
109 | | - return functools.partial(args[0], **kwargs) |
| 115 | + def wrap(self, fun: Callable, **kwargs) -> Callable: |
| 116 | + """Wraps a function for sequential execution. |
| 117 | +
|
| 118 | + This is a noop and kwargs are ignored.""" |
| 119 | + return fun |
110 | 120 |
|
111 | 121 | def wait(self, v: Any, *args, **kwargs) -> Tuple[list, list]: |
112 | 122 | return v, [] |
@@ -151,8 +161,17 @@ def put(self, v: T, *args, **kwargs) -> Union["ObjectRef[T]", T]: |
151 | 161 | except TypeError: |
152 | 162 | return v # type: ignore |
153 | 163 |
|
154 | | - def wrap(self, *args, **kwargs) -> RemoteFunction: |
155 | | - return ray.remote(*args, **kwargs) # type: ignore |
| 164 | + def wrap(self, fun: Callable, **kwargs) -> Callable: |
| 165 | + """Wraps a function as a ray remote. |
| 166 | +
|
| 167 | + :param fun: the function to wrap |
| 168 | + :param kwargs: keyword arguments to pass to @ray.remote |
| 169 | +
|
| 170 | + :return: The `.remote` method of the ray `RemoteFunction`. |
| 171 | + """ |
| 172 | + if len(kwargs) > 1: |
| 173 | + return ray.remote(**kwargs)(fun).remote # type: ignore |
| 174 | + return ray.remote(fun).remote # type: ignore |
156 | 175 |
|
157 | 176 | def wait( |
158 | 177 | self, |
@@ -213,3 +232,25 @@ def available_cpus() -> int: |
213 | 232 | if system() != "Linux": |
214 | 233 | return os.cpu_count() or 1 |
215 | 234 | return len(os.sched_getaffinity(0)) |
| 235 | + |
| 236 | + |
| 237 | +def effective_n_jobs(n_jobs: int, config: ParallelConfig = ParallelConfig()) -> int: |
| 238 | + """Returns the effective number of jobs. |
| 239 | +
|
| 240 | + This number may vary depending on the parallel backend and the resources |
| 241 | + available. |
| 242 | +
|
| 243 | + :param n_jobs: the number of jobs requested. If -1, the number of available |
| 244 | + CPUs is returned. |
| 245 | + :param config: instance of :class:`~pydvl.utils.config.ParallelConfig` with |
| 246 | + cluster address, number of cpus, etc. |
| 247 | + :return: the effective number of jobs, guaranteed to be >= 1. |
| 248 | + :raises RuntimeError: if the effective number of jobs returned by the backend |
| 249 | + is < 1. |
| 250 | + """ |
| 251 | + parallel_backend = init_parallel_backend(config) |
| 252 | + if (eff_n_jobs := parallel_backend.effective_n_jobs(n_jobs)) < 1: |
| 253 | + raise RuntimeError( |
| 254 | + f"Invalid number of jobs {eff_n_jobs} obtained from parallel backend {config.backend}" |
| 255 | + ) |
| 256 | + return eff_n_jobs |
0 commit comments