@@ -5,12 +5,13 @@
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Optional, Type
+from typing import Any, Optional, Set, Type
 
 from invoke.context import Context
 from leptonai.api.v1.client import APIClient
 from leptonai.api.v1.types.affinity import LeptonResourceAffinity
 from leptonai.api.v1.types.common import Metadata
+from leptonai.api.v1.types.dedicated_node_group import DedicatedNodeGroup
 from leptonai.api.v1.types.deployment import EnvVar, LeptonContainer, Mount
 from leptonai.api.v1.types.job import LeptonJob, LeptonJobState, LeptonJobUserSpec
 from leptonai.api.v1.types.replica import Replica
@@ -85,17 +86,32 @@ def move_data(self, sleep: float = 10) -> None:
             remote_path=relative_path
         )
 
-    def setup_distributed_pytorch(self) -> str:
+    def _node_group_id(self, client: APIClient) -> DedicatedNodeGroup:
         """
-        Runs a custom script from Lepton to setup the distributed PyTorch
-        environment variables required for distributed PyTorch jobs.
+        Find the node group ID for the requested node group.
+
+        Lists all node groups available to the user and matches the requested
+        node group against that list. Assumes node group names are unique.
         """
-        distributed_command = (
-            "wget -O init.sh https://raw.githubusercontent.com/leptonai/scripts/main/lepton_env_to_pytorch.sh && "
-            "chmod +x init.sh && "
-            "source init.sh"
-        )
-        return distributed_command
+        node_groups = client.nodegroup.list_all()
+        node_group_map = {ng.metadata.name: ng for ng in node_groups}
+        node_group_id = node_group_map[self.node_group]
+        return node_group_id
+
+    def _valid_node_ids(self, node_group_id: DedicatedNodeGroup, client: APIClient) -> Set[str]:
+        """
+        Find all of the node IDs available within the requested node group.
+
+        Lepton only schedules jobs on nodes that are part of the requested node
+        group and that match the user-specified resource shape. List all of the
+        node IDs within the node group and treat them as the available nodes.
+        """
+        valid_node_ids = set()
+        node_ids = client.nodegroup.list_nodes(node_group_id)
+        for node in node_ids:
+            valid_node_ids.add(node.metadata.id_)
+
+        return valid_node_ids
 
     def create_lepton_job(self, name: str):
         """
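Reviewer note: a quick sketch of how the two new helpers compose. The launcher name is hypothetical; the SDK calls are the ones the diff itself introduces. Note that despite its name, _node_group_id returns the full node group object, and the caller reads .metadata.id_ off it:

    client = APIClient()

    # Resolve the requested node group by name (raises KeyError for an
    # unknown name; create_lepton_job then validates metadata.id_).
    node_group = launcher._node_group_id(client)

    # Collect the IDs of every node in that group.
    node_ids = launcher._valid_node_ids(node_group, client)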
@@ -111,16 +127,13 @@ def create_lepton_job(self, name: str):
111127 f"chmod +x { self .lepton_job_dir } /launch_script.sh && bash { self .lepton_job_dir } /launch_script.sh"
112128 ]
113129
-        # Get node groups
-        node_groups = client.nodegroup.list_all()
-        node_group_map = {ng.metadata.name: ng for ng in node_groups}
-        node_group_id = node_group_map[self.node_group]
+        # Get ID of requested node group
+        node_group_id = self._node_group_id(client)
+        if not node_group_id.metadata.id_:
+            raise RuntimeError(f"Unable to find node group ID for node group {self.node_group}")
 
         # Get node IDs
-        valid_node_ids = set()
-        node_ids = client.nodegroup.list_nodes(node_group_id)
-        for node in node_ids:
-            valid_node_ids.add(node.metadata.id_)
+        valid_node_ids = self._valid_node_ids(node_group_id, client)
 
         job_spec = LeptonJobUserSpec(
             resource_shape=self.resource_shape,
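Reviewer note: the context cuts off before the rest of the LeptonJobUserSpec call. The resolved group and node set are presumably wired in through LeptonResourceAffinity (imported at the top of the file); the field names below are assumptions about the leptonai SDK and worth double-checking against the actual hunk:

    job_spec = LeptonJobUserSpec(
        resource_shape=self.resource_shape,
        # Assumed wiring: pin the job to the resolved node group and to
        # the node IDs collected by _valid_node_ids.
        affinity=LeptonResourceAffinity(
            allowed_dedicated_node_groups=[node_group_id.metadata.id_],
            allowed_nodes_in_node_group=list(valid_node_ids),
        ),
    )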
@@ -171,10 +184,16 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
         logger.info("Creating distributed workload")
         job = self.create_lepton_job(name)
         if not job:
-            raise RuntimeError(f"Failed to create Lepton job")
+            raise RuntimeError("Failed to create Lepton job")
 
         job_id = job.metadata.id_
+
+        if not job_id:
+            raise RuntimeError("Failed to retrieve job information")
         status = self.status(job_id)
+
+        if not status:
+            raise RuntimeError("Failed to retrieve job status")
         return job_id, status
 
     def nnodes(self) -> int:
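Reviewer note: with these guards, launch either returns a usable (job_id, status) pair or raises, so callers no longer have to defend against a None ID leaking into status(). A minimal caller sketch; the launcher instance and command are hypothetical:

    try:
        job_id, status = launcher.launch("nemo-train", ["bash", "train.sh"])
    except RuntimeError as err:
        # Job creation failed, or the job came back without an ID/status.
        logger.error("Lepton launch failed: %s", err)
        raise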
@@ -193,7 +212,7 @@ def status(self, job_id: str) -> Optional[LeptonJobState]:
         client = APIClient()
         job = client.job.get(job_id)
 
-        if not job:
+        if not job or not job.status:
             return LeptonJobState.Unknown
 
         # Lepton marks a job as Running when at least one pod is running
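Reviewer note: both status readers now treat a job without a populated status block as Unknown instead of crashing on attribute access further down. The effective shape of the guard; the status.state read is an assumption about the truncated remainder, based on the SDK's job status type:

    job = client.job.get(job_id)
    if not job or not job.status:
        return LeptonJobState.Unknown
    # Past this point job.status is known to be populated.
    state = job.status.state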
@@ -224,6 +243,8 @@ def _first_replica(job_id: str) -> Replica:
 
         for replica in replicas:
             replica_id = replica.metadata.id_
+            if not replica_id:
+                continue
             # The first replica has the pattern <job-id>-0-xxxxx
             # where xxxxx is a unique ID for each worker. Subsequent
             # workers increase the number between <job-id> and the
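Reviewer note: the comment pins down the naming scheme the loop relies on, and the new continue keeps a missing ID from reaching the string test. Under that scheme the selection presumably reduces to a prefix match like this (a sketch, not the verbatim body, since the hunk is truncated):

    for replica in replicas:
        replica_id = replica.metadata.id_
        if not replica_id:
            continue
        # First worker is "<job-id>-0-xxxxx"; later workers use -1-, -2-, ...
        if replica_id.startswith(f"{job_id}-0-"):
            return replica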
@@ -240,7 +261,7 @@ def _status(job_id: str):
         client = APIClient()
         job = client.job.get(job_id)
 
-        if not job:
+        if not job or not job.status:
             return LeptonJobState.Unknown
 
         # Lepton marks a job as Running when at least one pod is running