from typing import Dict

import torch
from monarch.actor import Actor, current_rank, endpoint, HostMesh, ProcMesh, this_host
from monarch.job import SlurmJob
from monarch.utils import setup_env_for_distributed
from torchtitan.config import ConfigManager, JobConfig
from torchtitan.tools.logging import init_logger, logger
from torchtitan.train import Trainer
from utils.failure import Failure, FailureActor, FailureController


# ==== Allocation boilerplate ====
class MonarchSlurm:
    """Provision and track SLURM-backed host meshes via Monarch's `SlurmJob` API.

    One `SlurmJob` is created per mesh name; handles are kept in
    `job_handles` and torn down automatically at interpreter exit.
    """

    # Prefix used for every SLURM job name this scheduler submits.
    job_name_prefix: str = "monarch-torchft"

    def __init__(self):
        # Maps mesh name -> live SlurmJob handle.
        self.job_handles: Dict[str, SlurmJob] = {}
        # Best-effort cleanup of all submitted jobs when the process exits.
        atexit.register(self.kill_jobs)

    async def get_or_create_job(
        self, mesh_name: str, nodes_per_mesh: int = 1, gpus_per_node: int = 8
    ) -> None:
        """Submit a SLURM job hosting `mesh_name` and remember its handle.

        Args:
            mesh_name: Logical name of the host mesh (also keys `job_handles`).
            nodes_per_mesh: Number of SLURM nodes backing the mesh.
            gpus_per_node: GPUs requested per node.
        """
        job = SlurmJob(
            meshes={mesh_name: nodes_per_mesh},
            gpus_per_node=gpus_per_node,
            job_name=f"{self.job_name_prefix}-{mesh_name}",
        )
        # NOTE(review): presumably apply() is the submit/get-or-create step —
        # confirm whether it reuses an existing job of the same name.
        job.apply()
        self.job_handles[mesh_name] = job

    def kill_jobs(self):
        """Destroy every job this scheduler has submitted (atexit hook)."""
        for mesh_name in self.job_handles.keys():
            self.kill_job(mesh_name)

    def kill_job(self, mesh_name: str):
        """Destroy the SLURM job backing `mesh_name`; log (don't raise) on failure."""
        try:
            job = self.job_handles[mesh_name]
            logger.info(f"Destroying job for mesh {mesh_name}")
            job.kill()
        except Exception as e:
            # Best-effort teardown: a failed kill must not abort the others.
            logger.exception(f"Failed to destroy job for {mesh_name}: {e}")

    def proc_mesh(self, mesh_name: str, num_procs: int) -> ProcMesh:
        """Spawn `num_procs` processes (one per GPU) on the job's host mesh.

        Args:
            mesh_name: Name of a mesh previously created via `get_or_create_job`.
            num_procs: Number of processes to spawn per host ("gpus" dimension).

        Returns:
            The spawned `ProcMesh`.
        """
        job = self.job_handles[mesh_name]
        # NOTE(review): job state is cached on disk, presumably so repeated
        # lookups skip re-querying SLURM — confirm the path convention.
        cached_path = f".{mesh_name}/job_state.pkl"
        mesh: HostMesh = getattr(job.state(cached_path=cached_path), mesh_name)
        return mesh.spawn_procs({"gpus": num_procs})

# ==== allocation boilerplate ====
@@ -177,34 +151,31 @@ async def start_replica(self) -> None:
177151 init_logger ()
178152 logger .info (f"{ self .uid } Spawning trainers" )
179153
180- trainers_proc_mesh : ProcMesh | None = None
181154 try :
182155 trainers_proc_mesh = self .scheduler .proc_mesh (
183156 f"replica_{ self .replica_id } " ,
184- self .spec .hosts_per_replica ,
185- self .spec .gpus_per_node ,
186- )
187- await trainers_proc_mesh .logging_option (stream_to_client = True )
188- await setup_env_for_distributed (trainers_proc_mesh )
189-
190- training_actors = trainers_proc_mesh .spawn (
191- "training_actors" ,
192- TrainingActor ,
193- self .spec .job_config ,
194- self .replica_id ,
157+ num_procs = self .spec .gpus_per_node ,
195158 )
196159
197- self . failure_actors = trainers_proc_mesh . spawn (
198- "failure_actors" , FailureActor
199- )
160+ async with trainers_proc_mesh :
161+ await trainers_proc_mesh . logging_option ( stream_to_client = True )
162+ await setup_env_for_distributed ( trainers_proc_mesh )
200163
201- logger .info (f"{ self .uid } Starting trainers" )
202- await training_actors .start_training .call (self .spec .lighthouse_address )
203- await trainers_proc_mesh .stop ()
204- except Exception as e :
205- if trainers_proc_mesh :
206- await trainers_proc_mesh .stop ()
207- raise e
164+ training_actors = trainers_proc_mesh .spawn (
165+ "training_actors" ,
166+ TrainingActor ,
167+ self .spec .job_config ,
168+ self .replica_id ,
169+ )
170+
171+ self .failure_actors = trainers_proc_mesh .spawn (
172+ "failure_actors" , FailureActor
173+ )
174+
175+ logger .info (f"{ self .uid } Starting trainers" )
176+ await training_actors .start_training .call (self .spec .lighthouse_address )
177+ except Exception :
178+ raise
208179
209180 @endpoint
210181 async def inject_failure (self , failure_type : Failure ):
@@ -216,8 +187,7 @@ async def inject_failure(self, failure_type: Failure):
216187
217188 await self .failure_actors .fail .choose (failure_type )
218189 except Exception as e :
219- error_msg = f"{ self .uid } Injected failure: { e } "
220- logger .error (error_msg )
190+ logger .exception (f"{ self .uid } Injected failure: { e } " )
221191 else :
222192 error_msg = f"{ self .uid } No failure actors available"
223193 logger .error (error_msg )
@@ -268,7 +238,7 @@ async def start_training(self) -> None:
268238 async def start_lighthouse (self ) -> None :
269239 if self .spec .remote_lighthouse :
270240 await self .scheduler .get_or_create_job ("lighthouse" )
271- self .lighthouse_mesh = self .scheduler .proc_mesh ("lighthouse" , num_gpus = 1 )
241+ self .lighthouse_mesh = self .scheduler .proc_mesh ("lighthouse" , num_procs = 1 )
272242 else :
273243 self .lighthouse_mesh = this_host ().spawn_procs ({"gpus" : 1 })
274244
@@ -287,7 +257,7 @@ async def stop_lighthouse(self) -> None:
287257 await self .lighthouse_mesh .stop ()
288258 logger .info ("[Controller] Lighthouse stopped" )
289259 except Exception as e :
290- logger .warning (f"[Controller] Failed to stop lighthouse: { e } " )
260+ logger .exception (f"[Controller] Failed to stop lighthouse: { e } " )
291261
292262 async def _run_replica (self , replica_id : int , attempt_number : int ) -> None :
293263 if attempt_number >= MAX_ATTEMPT :
@@ -300,7 +270,7 @@ async def _run_replica(self, replica_id: int, attempt_number: int) -> None:
300270 await self ._teardown (replica_id )
301271 except Exception as e :
302272 await self ._teardown (replica_id )
303- logger .info (f"[Controller] replica { replica_id } failed: { e } " )
273+ logger .exception (f"[Controller] replica { replica_id } failed: { e } " )
304274 await self ._run_replica (replica_id , attempt_number + 1 )
305275
306276 async def _spin_up_replica (self , replica_id : int , attempt_number : int = 0 ) -> None :
@@ -332,11 +302,18 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> No
332302 async def _teardown (self , replica_id : int ) -> None :
333303 try :
334304 replica = self .replicas [replica_id ]
335- await replica .proc_mesh .stop ()
305+ try :
306+ await replica .proc_mesh .stop ()
307+ except Exception as e :
308+ logger .exception (
309+ f"[Controller] Failed to stop replica { replica_id } , it may already be stopped. { e } "
310+ )
336311 del self .replicas [replica_id ]
337312 del replica .proc_mesh
338313 except Exception as e :
339- logger .error (f"[Controller] Failed to _teardown replica { replica_id } : { e } " )
314+ logger .exception (
315+ f"[Controller] Failed to teardown replica { replica_id } : { e } "
316+ )
340317
341318
342319# === CLI / CONFIG === #