|
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | 16 | from dataclasses import dataclass |
17 | | -from typing import Any, Optional |
| 17 | +from typing import Any, Optional, Type |
18 | 18 |
|
19 | 19 | from nemo_run.core.execution.base import Executor |
20 | | -from nemo_run.core.execution.kuberay import KubeRayExecutor |
21 | 20 | from nemo_run.core.execution.slurm import SlurmExecutor |
22 | | -from nemo_run.run.ray.kuberay import KubeRayJob |
23 | 21 | from nemo_run.run.ray.slurm import SlurmRayJob |
24 | 22 |
|
| 23 | +# Import guard for Kubernetes dependencies |
| 24 | +try: |
| 25 | + from nemo_run.core.execution.kuberay import KubeRayExecutor |
| 26 | + from nemo_run.run.ray.kuberay import KubeRayJob |
| 27 | + |
| 28 | + _KUBERAY_AVAILABLE = True |
| 29 | +except ImportError: |
| 30 | + KubeRayExecutor = None |
| 31 | + KubeRayJob = None |
| 32 | + _KUBERAY_AVAILABLE = False |
| 33 | + |
25 | 34 |
|
26 | 35 | @dataclass(kw_only=True) |
27 | 36 | class RayJob: |
28 | 37 | """Backend-agnostic convenience wrapper around Ray *jobs*.""" |
29 | 38 |
|
30 | | - BACKEND_MAP = { |
31 | | - KubeRayExecutor: KubeRayJob, |
32 | | - SlurmExecutor: SlurmRayJob, |
33 | | - } |
34 | | - |
35 | 39 | name: str |
36 | 40 | executor: Executor |
37 | 41 | pre_ray_start_commands: Optional[list[str]] = None |
38 | 42 |
|
39 | 43 | def __post_init__(self) -> None: # noqa: D401 – simple implementation |
40 | | - if self.executor.__class__ not in self.BACKEND_MAP: |
| 44 | + backend_map: dict[Type[Executor], Type[Any]] = { |
| 45 | + SlurmExecutor: SlurmRayJob, |
| 46 | + } |
| 47 | + |
| 48 | + if _KUBERAY_AVAILABLE and KubeRayExecutor is not None and KubeRayJob is not None: |
| 49 | + backend_map[KubeRayExecutor] = KubeRayJob |
| 50 | + |
| 51 | + if self.executor.__class__ not in backend_map: |
41 | 52 | raise ValueError(f"Unsupported executor: {self.executor.__class__}") |
42 | 53 |
|
43 | | - self.backend = self.BACKEND_MAP[self.executor.__class__]( |
44 | | - name=self.name, executor=self.executor |
45 | | - ) |
| 54 | + backend_cls = backend_map[self.executor.__class__] |
| 55 | + self.backend = backend_cls(name=self.name, executor=self.executor) |
46 | 56 |
|
47 | 57 | # ------------------------------------------------------------------ |
48 | 58 | # Public API |
|
0 commit comments