Skip to content

Commit 39b3770

Browse files
committed
Add error handling to LeptonExecutor
Handle more possible failure scenarios for the LeptonExecutor where the code could run into a bad state; the user should be alerted with helpful debug info. Signed-off-by: Robert Clark <[email protected]>
1 parent 0fd39ed commit 39b3770

File tree

1 file changed

+41
-20
lines changed

1 file changed

+41
-20
lines changed

nemo_run/core/execution/lepton.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
import time
66
from dataclasses import dataclass, field
77
from pathlib import Path
8-
from typing import Any, Optional, Type
8+
from typing import Any, Optional, Set, Type
99

1010
from invoke.context import Context
1111
from leptonai.api.v1.client import APIClient
1212
from leptonai.api.v1.types.affinity import LeptonResourceAffinity
1313
from leptonai.api.v1.types.common import Metadata
14+
from leptonai.api.v1.types.dedicated_node_group import DedicatedNodeGroup
1415
from leptonai.api.v1.types.deployment import EnvVar, LeptonContainer, Mount
1516
from leptonai.api.v1.types.job import (LeptonJob, LeptonJobState,
1617
LeptonJobUserSpec)
@@ -86,17 +87,32 @@ def move_data(self, sleep: float = 10) -> None:
8687
remote_path=relative_path
8788
)
8889

89-
def setup_distributed_pytorch(self) -> str:
90+
def _node_group_id(self, client: "APIClient") -> "DedicatedNodeGroup":
    """
    Find the node group object for the node group requested by the user.

    Lists all node groups available to the user and matches the node group
    requested by the user against that list. Assumes there are no duplicate
    node group names.

    Args:
        client: Authenticated Lepton API client used to list node groups.

    Returns:
        The matching ``DedicatedNodeGroup`` (note: the full object, not just
        its ID, despite the method name — callers read ``.metadata.id_``).

    Raises:
        RuntimeError: If no node group with the requested name exists, so the
            user gets an actionable message instead of a bare ``KeyError``.
    """
    node_groups = client.nodegroup.list_all()
    # Map each node group's name to its object; duplicate names are not
    # expected, but a duplicate would silently keep the last one listed.
    node_group_map = {ng.metadata.name: ng for ng in node_groups}
    node_group = node_group_map.get(self.node_group)
    if node_group is None:
        raise RuntimeError(
            f"Unable to find node group {self.node_group}. "
            f"Available node groups: {sorted(node_group_map)}"
        )
    return node_group
101+
102+
def _valid_node_ids(self, node_group_id: "DedicatedNodeGroup", client: "APIClient") -> Set[str]:
    """
    Find all of the node IDs that are available within the requested node group.

    Lepton will only schedule jobs on nodes that are part of the requested
    node group that match the user-specified resource shape. List all of the
    nodes within the node group and collect their IDs as the available nodes.

    Args:
        node_group_id: The node group whose nodes should be listed.
        client: Authenticated Lepton API client used to list nodes.

    Returns:
        The set of node IDs found in the node group; nodes without an ID are
        skipped.
    """
    nodes = client.nodegroup.list_nodes(node_group_id)
    # Skip nodes with a missing/empty ID so the set never contains None
    # (mirrors the replica-ID guard elsewhere in this executor).
    return {node.metadata.id_ for node in nodes if node.metadata.id_}
100116

101117
def create_lepton_job(self, name: str):
102118
"""
@@ -114,16 +130,13 @@ def create_lepton_job(self, name: str):
114130
f"chmod +x {self.lepton_job_dir}/launch_script.sh && bash {self.lepton_job_dir}/launch_script.sh"
115131
]
116132

117-
# Get node groups
118-
node_groups = client.nodegroup.list_all()
119-
node_group_map = {ng.metadata.name: ng for ng in node_groups}
120-
node_group_id = node_group_map[self.node_group]
133+
# Get ID of requested node group
134+
node_group_id = self._node_group_id(client)
135+
if not node_group_id.metadata.id_:
136+
raise RuntimeError(f"Unable to find node group ID for node group {self.node_group}")
121137

122138
# Get node IDs
123-
valid_node_ids = set()
124-
node_ids = client.nodegroup.list_nodes(node_group_id)
125-
for node in node_ids:
126-
valid_node_ids.add(node.metadata.id_)
139+
valid_node_ids = self._valid_node_ids(node_group_id, client)
127140

128141
job_spec = LeptonJobUserSpec(
129142
resource_shape=self.resource_shape,
@@ -187,7 +200,13 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
187200
raise RuntimeError(f"Failed to create Lepton job")
188201

189202
job_id = job.metadata.id_
203+
204+
if not job_id:
205+
raise RuntimeError("Failed to retrieve job information")
190206
status = self.status(job_id)
207+
208+
if not status:
209+
raise RuntimeError("Failed to retrieve job status")
191210
return job_id, status
192211

193212
def nnodes(self) -> int:
@@ -206,7 +225,7 @@ def status(self, job_id: str) -> Optional[LeptonJobState]:
206225
client = APIClient()
207226
job = client.job.get(job_id)
208227

209-
if not job:
228+
if not job or not job.status:
210229
return LeptonJobState.Unknown
211230

212231
# Lepton marks a job as Running when at least one pod is running
@@ -237,6 +256,8 @@ def _first_replica(job_id: str) -> Replica:
237256

238257
for replica in replicas:
239258
replica_id = replica.metadata.id_
259+
if not replica_id:
260+
continue
240261
# The first replica has the pattern <job-id>-0-xxxxx
241262
# where xxxxx is a unique ID for each worker. Subsequent
242263
# workers increase the number between <job-id> and the
@@ -253,7 +274,7 @@ def _status(job_id: str):
253274
client = APIClient()
254275
job = client.job.get(job_id)
255276

256-
if not job:
277+
if not job or not job.status:
257278
return LeptonJobState.Unknown
258279

259280
# Lepton marks a job as Running when at least one pod is running

0 commit comments

Comments
 (0)