Skip to content

Commit 3700b51

Browse files
committed
Import guard k8s import in Ray Cluster and Job
Signed-off-by: Hemil Desai <[email protected]>
1 parent 252edfb commit 3700b51

File tree

2 files changed

+43
-22
lines changed

2 files changed

+43
-22
lines changed

nemo_run/run/ray/cluster.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,43 @@
1414
# limitations under the License.
1515

1616
from dataclasses import dataclass
17-
from typing import Optional
17+
from typing import Optional, Type
1818

1919
from nemo_run.core.execution.base import Executor
20-
from nemo_run.core.execution.kuberay import KubeRayExecutor
2120
from nemo_run.core.execution.slurm import SlurmExecutor
22-
from nemo_run.run.ray.kuberay import KubeRayCluster
2321
from nemo_run.run.ray.slurm import SlurmRayCluster
2422

23+
# Import guard for Kubernetes dependencies
24+
try:
25+
from nemo_run.core.execution.kuberay import KubeRayExecutor
26+
from nemo_run.run.ray.kuberay import KubeRayCluster
27+
28+
_KUBERAY_AVAILABLE = True
29+
except ImportError:
30+
KubeRayExecutor = None
31+
KubeRayCluster = None
32+
_KUBERAY_AVAILABLE = False
33+
2534
USE_WITH_RAY_CLUSTER_KEY = "use_with_ray_cluster"
2635

2736

2837
@dataclass(kw_only=True)
2938
class RayCluster:
30-
BACKEND_MAP = {
31-
KubeRayExecutor: KubeRayCluster,
32-
SlurmExecutor: SlurmRayCluster,
33-
}
34-
3539
name: str
3640
executor: Executor
3741

3842
def __post_init__(self):
39-
if self.executor.__class__ not in self.BACKEND_MAP:
43+
backend_map: dict[Type[Executor], Type] = {
44+
SlurmExecutor: SlurmRayCluster,
45+
}
46+
47+
if _KUBERAY_AVAILABLE and KubeRayExecutor is not None and KubeRayCluster is not None:
48+
backend_map[KubeRayExecutor] = KubeRayCluster
49+
50+
if self.executor.__class__ not in backend_map:
4051
raise ValueError(f"Unsupported executor: {self.executor.__class__}")
4152

42-
backend_cls = self.BACKEND_MAP[self.executor.__class__]
53+
backend_cls = backend_map[self.executor.__class__]
4354
self.backend = backend_cls(name=self.name, executor=self.executor) # type: ignore[arg-type]
4455

4556
self._port_forward_map = {}

nemo_run/run/ray/job.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,35 +14,45 @@
1414
# limitations under the License.
1515

1616
from dataclasses import dataclass
17-
from typing import Any, Optional
17+
from typing import Any, Optional, Type
1818

1919
from nemo_run.core.execution.base import Executor
20-
from nemo_run.core.execution.kuberay import KubeRayExecutor
2120
from nemo_run.core.execution.slurm import SlurmExecutor
22-
from nemo_run.run.ray.kuberay import KubeRayJob
2321
from nemo_run.run.ray.slurm import SlurmRayJob
2422

23+
# Import guard for Kubernetes dependencies
24+
try:
25+
from nemo_run.core.execution.kuberay import KubeRayExecutor
26+
from nemo_run.run.ray.kuberay import KubeRayJob
27+
28+
_KUBERAY_AVAILABLE = True
29+
except ImportError:
30+
KubeRayExecutor = None
31+
KubeRayJob = None
32+
_KUBERAY_AVAILABLE = False
33+
2534

2635
@dataclass(kw_only=True)
2736
class RayJob:
2837
"""Backend-agnostic convenience wrapper around Ray *jobs*."""
2938

30-
BACKEND_MAP = {
31-
KubeRayExecutor: KubeRayJob,
32-
SlurmExecutor: SlurmRayJob,
33-
}
34-
3539
name: str
3640
executor: Executor
3741
pre_ray_start_commands: Optional[list[str]] = None
3842

3943
def __post_init__(self) -> None: # noqa: D401 – simple implementation
40-
if self.executor.__class__ not in self.BACKEND_MAP:
44+
backend_map: dict[Type[Executor], Type[Any]] = {
45+
SlurmExecutor: SlurmRayJob,
46+
}
47+
48+
if _KUBERAY_AVAILABLE and KubeRayExecutor is not None and KubeRayJob is not None:
49+
backend_map[KubeRayExecutor] = KubeRayJob
50+
51+
if self.executor.__class__ not in backend_map:
4152
raise ValueError(f"Unsupported executor: {self.executor.__class__}")
4253

43-
self.backend = self.BACKEND_MAP[self.executor.__class__](
44-
name=self.name, executor=self.executor
45-
)
54+
backend_cls = backend_map[self.executor.__class__]
55+
self.backend = backend_cls(name=self.name, executor=self.executor)
4656

4757
# ------------------------------------------------------------------
4858
# Public API

0 commit comments

Comments
 (0)