Commit 7d3d685 (parent 3a1aec4)

Add error handling to LeptonExecutor

Handle more possible failure scenarios for the LeptonExecutor where the code could run into a bad state and the user should be alerted with helpful debug info.

Signed-off-by: Robert Clark <roclark@nvidia.com>

1 file changed: 42 additions, 21 deletions


nemo_run/core/execution/lepton.py

@@ -5,12 +5,13 @@
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Optional, Type
+from typing import Any, Optional, Set, Type
 
 from invoke.context import Context
 from leptonai.api.v1.client import APIClient
 from leptonai.api.v1.types.affinity import LeptonResourceAffinity
 from leptonai.api.v1.types.common import Metadata
+from leptonai.api.v1.types.dedicated_node_group import DedicatedNodeGroup
 from leptonai.api.v1.types.deployment import EnvVar, LeptonContainer, Mount
 from leptonai.api.v1.types.job import LeptonJob, LeptonJobState, LeptonJobUserSpec
 from leptonai.api.v1.types.replica import Replica
@@ -85,17 +86,32 @@ def move_data(self, sleep: float = 10) -> None:
                 remote_path=relative_path
             )
 
-    def setup_distributed_pytorch(self) -> str:
+    def _node_group_id(self, client: APIClient) -> DedicatedNodeGroup:
         """
-        Runs a custom script from Lepton to setup the distributed PyTorch
-        environment variables required for distributed PyTorch jobs.
+        Find the node group ID for the passed node group.
+
+        Lists all node groups available to the user and matches the node group requested
+        from the user with the list of node groups. Assumes there are no duplicate node groups.
         """
-        distributed_command = (
-            "wget -O init.sh https://raw.githubusercontent.com/leptonai/scripts/main/lepton_env_to_pytorch.sh && "
-            "chmod +x init.sh && "
-            "source init.sh"
-        )
-        return distributed_command
+        node_groups = client.nodegroup.list_all()
+        node_group_map = {ng.metadata.name: ng for ng in node_groups}
+        node_group_id = node_group_map[self.node_group]
+        return node_group_id
+
+    def _valid_node_ids(self, node_group_id: DedicatedNodeGroup, client: APIClient) -> Set:
+        """
+        Find all of the node IDs that are available within the requested node group.
+
+        Lepton will only schedule jobs on nodes that are part of the requested node
+        group that match the user-specified resource shape. List all of the node IDs
+        within the node group and set them as available nodes.
+        """
+        valid_node_ids = set()
+        node_ids = client.nodegroup.list_nodes(node_group_id)
+        for node in node_ids:
+            valid_node_ids.add(node.metadata.id_)
+
+        return valid_node_ids
 
     def create_lepton_job(self, name: str):
         """
@@ -111,16 +127,13 @@ def create_lepton_job(self, name: str):
             f"chmod +x {self.lepton_job_dir}/launch_script.sh && bash {self.lepton_job_dir}/launch_script.sh"
         ]
 
-        # Get node groups
-        node_groups = client.nodegroup.list_all()
-        node_group_map = {ng.metadata.name: ng for ng in node_groups}
-        node_group_id = node_group_map[self.node_group]
+        # Get ID of requested node group
+        node_group_id = self._node_group_id(client)
+        if not node_group_id.metadata.id_:
+            raise RuntimeError(f"Unable to find node group ID for node group {self.node_group}")
 
         # Get node IDs
-        valid_node_ids = set()
-        node_ids = client.nodegroup.list_nodes(node_group_id)
-        for node in node_ids:
-            valid_node_ids.add(node.metadata.id_)
+        valid_node_ids = self._valid_node_ids(node_group_id, client)
 
         job_spec = LeptonJobUserSpec(
             resource_shape=self.resource_shape,
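For context, the node group ID and node ID set gathered here plausibly feed the job's scheduling affinity. A sketch of that wiring, assuming the LeptonResourceAffinity field names (allowed_dedicated_node_groups, allowed_nodes_in_node_group); the actual spec construction sits outside this hunk:

    # Assumed wiring of the collected IDs into the job spec's affinity;
    # field names are an assumption, not shown in this diff.
    affinity = LeptonResourceAffinity(
        allowed_dedicated_node_groups=[node_group_id.metadata.id_],
        allowed_nodes_in_node_group=list(valid_node_ids),
    )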
@@ -171,10 +184,16 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
         logger.info("Creating distributed workload")
         job = self.create_lepton_job(name)
         if not job:
-            raise RuntimeError(f"Failed to create Lepton job")
+            raise RuntimeError("Failed to create Lepton job")
 
         job_id = job.metadata.id_
+
+        if not job_id:
+            raise RuntimeError("Failed to retrieve job information")
         status = self.status(job_id)
+
+        if not status:
+            raise RuntimeError("Failed to retrieve job status")
         return job_id, status
 
     def nnodes(self) -> int:
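With these guards, launch fails fast with a RuntimeError instead of propagating a None job ID or status. A hedged caller-side sketch (executor is a hypothetical, already-configured LeptonExecutor; the command is arbitrary):

    # Hypothetical caller surfacing the new failure modes.
    try:
        job_id, status = executor.launch("demo-job", ["bash", "train.sh"])
    except RuntimeError as err:
        logger.error("Lepton job submission failed: %s", err)
        raise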
@@ -193,7 +212,7 @@ def status(self, job_id: str) -> Optional[LeptonJobState]:
         client = APIClient()
         job = client.job.get(job_id)
 
-        if not job:
+        if not job or not job.status:
             return LeptonJobState.Unknown
 
         # Lepton marks a job as Running when at least one pod is running
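Because a missing job.status now maps to LeptonJobState.Unknown instead of raising an AttributeError, callers can treat Unknown as transient. An illustrative polling loop (the interval and retry budget are arbitrary choices, and executor is hypothetical):

    import time

    # Poll until the job reports a concrete state, up to a fixed budget.
    state = LeptonJobState.Unknown
    for _ in range(30):
        state = executor.status(job_id)
        if state != LeptonJobState.Unknown:
            break
        time.sleep(10)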
@@ -224,6 +243,8 @@ def _first_replica(job_id: str) -> Replica:
 
         for replica in replicas:
             replica_id = replica.metadata.id_
+            if not replica_id:
+                continue
             # The first replica has the pattern <job-id>-0-xxxxx
             # where xxxxx is a unique ID for each worker. Subsequent
             # workers increase the number between <job-id> and the
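The new guard keeps the pattern match from comparing against a missing ID. The naming convention the comment describes can be expressed as a small standalone check; this helper is illustrative only, not part of the executor:

    import re

    def is_first_replica(job_id: str, replica_id: str) -> bool:
        # First replica IDs look like "<job-id>-0-xxxxx" per the comment above.
        return re.fullmatch(rf"{re.escape(job_id)}-0-\w+", replica_id) is not None

    assert is_first_replica("job123", "job123-0-abcde")
    assert not is_first_replica("job123", "job123-1-abcde")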
@@ -240,7 +261,7 @@ def _status(job_id: str):
         client = APIClient()
         job = client.job.get(job_id)
 
-        if not job:
+        if not job or not job.status:
             return LeptonJobState.Unknown
 
         # Lepton marks a job as Running when at least one pod is running
