55import time
66from dataclasses import dataclass , field
77from pathlib import Path
8- from typing import Any , Optional , Type
8+ from typing import Any , Optional , Set , Type
99
1010from invoke .context import Context
1111from leptonai .api .v1 .client import APIClient
1212from leptonai .api .v1 .types .affinity import LeptonResourceAffinity
1313from leptonai .api .v1 .types .common import Metadata
14+ from leptonai .api .v1 .types .dedicated_node_group import DedicatedNodeGroup
1415from leptonai .api .v1 .types .deployment import EnvVar , LeptonContainer , Mount
1516from leptonai .api .v1 .types .job import (LeptonJob , LeptonJobState ,
1617 LeptonJobUserSpec )
@@ -86,17 +87,32 @@ def move_data(self, sleep: float = 10) -> None:
8687 remote_path = relative_path
8788 )
8889
89- def setup_distributed_pytorch (self ) -> str :
90+ def _node_group_id (self , client : APIClient ) -> DedicatedNodeGroup :
9091 """
91- Runs a custom script from Lepton to setup the distributed PyTorch
92- environment variables required for distributed PyTorch jobs.
92+ Find the node group ID for the passed node group.
93+
94+ Lists all node groups available to the user and matches the node group requested
95+ from the user with the list of node groups. Assumes there are no duplicate node groups.
9396 """
94- distributed_command = (
95- "wget -O init.sh https://raw.githubusercontent.com/leptonai/scripts/main/lepton_env_to_pytorch.sh && "
96- "chmod +x init.sh && "
97- "source init.sh"
98- )
99- return distributed_command
97+ node_groups = client .nodegroup .list_all ()
98+ node_group_map = {ng .metadata .name : ng for ng in node_groups }
99+ node_group_id = node_group_map [self .node_group ]
100+ return node_group_id
101+
102+ def _valid_node_ids (self , node_group_id : DedicatedNodeGroup , client : APIClient ) -> Set :
103+ """
104+ Find all of the node IDs that are available within the requested node group.
105+
106+ Lepton will only schedule jobs on nodes that are part of the requested node
107+ group that match the user-specified resource shape. List all of the node IDs
108+ within the node group and set them as available nodes.
109+ """
110+ valid_node_ids = set ()
111+ node_ids = client .nodegroup .list_nodes (node_group_id )
112+ for node in node_ids :
113+ valid_node_ids .add (node .metadata .id_ )
114+
115+ return valid_node_ids
100116
101117 def create_lepton_job (self , name : str ):
102118 """
@@ -114,16 +130,13 @@ def create_lepton_job(self, name: str):
114130 f"chmod +x { self .lepton_job_dir } /launch_script.sh && bash { self .lepton_job_dir } /launch_script.sh"
115131 ]
116132
117- # Get node groups
118- node_groups = client . nodegroup . list_all ( )
119- node_group_map = { ng .metadata .name : ng for ng in node_groups }
120- node_group_id = node_group_map [ self .node_group ]
133+ # Get ID of requested node group
134+ node_group_id = self . _node_group_id ( client )
135+ if not node_group_id .metadata .id_ :
136+ raise RuntimeError ( f"Unable to find node group ID for node group { self .node_group } " )
121137
122138 # Get node IDs
123- valid_node_ids = set ()
124- node_ids = client .nodegroup .list_nodes (node_group_id )
125- for node in node_ids :
126- valid_node_ids .add (node .metadata .id_ )
139+ valid_node_ids = self ._valid_node_ids (node_group_id , client )
127140
128141 job_spec = LeptonJobUserSpec (
129142 resource_shape = self .resource_shape ,
@@ -187,7 +200,13 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
187200 raise RuntimeError (f"Failed to create Lepton job" )
188201
189202 job_id = job .metadata .id_
203+
204+ if not job_id :
205+ raise RuntimeError ("Failed to retrieve job information" )
190206 status = self .status (job_id )
207+
208+ if not status :
209+ raise RuntimeError ("Failed to retrieve job status" )
191210 return job_id , status
192211
193212 def nnodes (self ) -> int :
@@ -206,7 +225,7 @@ def status(self, job_id: str) -> Optional[LeptonJobState]:
206225 client = APIClient ()
207226 job = client .job .get (job_id )
208227
209- if not job :
228+ if not job or not job . status :
210229 return LeptonJobState .Unknown
211230
212231 # Lepton marks a job as Running when at least one pod is running
@@ -237,6 +256,8 @@ def _first_replica(job_id: str) -> Replica:
237256
238257 for replica in replicas :
239258 replica_id = replica .metadata .id_
259+ if not replica_id :
260+ continue
240261 # The first replica has the pattern <job-id>-0-xxxxx
241262 # where xxxxx is a unique ID for each worker. Subsequent
242263 # workers increase the number between <job-id> and the
@@ -253,7 +274,7 @@ def _status(job_id: str):
253274 client = APIClient ()
254275 job = client .job .get (job_id )
255276
256- if not job :
277+ if not job or not job . status :
257278 return LeptonJobState .Unknown
258279
259280 # Lepton marks a job as Running when at least one pod is running
0 commit comments