2424from torchtitan .config import ConfigManager , JobConfig
2525from torchtitan .tools .logging import init_logger , logger
2626from torchtitan .train import Trainer
27+ from utils .failure import Failure , FailureActor , FailureController
2728
2829
2930# ==== Allocation boilerplate - much of this will be upstreamed into Monarch ====
3031class MonarchSlurm :
3132 # Cluster Configuration - update these values for your specific cluster
32- machine : str = "aws_g5.12xlarge "
33- machine_memory : int = 186777
33+ machine : str = "gpu.xlarge "
34+ machine_memory : int = 2062607
3435 job_name_prefix : str = "monarch-torchft"
3536
36- job_handles : Dict [str , str ] = {}
37+ def __init__ (self ):
38+ self .job_handles : Dict [str , str ] = {}
39+ atexit .register (self .kill_jobs )
3740
38- @classmethod
39- def get_config (cls , mesh_name : str , nodes_per_mesh : int ) -> Config :
41+ def get_config (self , mesh_name : str , nodes_per_mesh : int ) -> Config :
4042 mesh = [f"{ mesh_name } :{ nodes_per_mesh } :{ MonarchSlurm .machine } " ]
41- appdef = hyperactor .host_mesh (meshes = mesh )
43+ # to enable relative import of utils on actors
44+ current_dir = os .path .dirname (os .path .abspath (__file__ ))
45+ env = {"PYTHONPATH" : current_dir }
46+
47+ appdef = hyperactor .host_mesh (meshes = mesh , env = env )
4248
4349 for role in appdef .roles :
4450 role .resource .memMB = MonarchSlurm .machine_memory
4551
4652 return Config (scheduler = "slurm" , appdef = appdef )
4753
48- @classmethod
49- async def get_or_create_job (cls , mesh_name : str , nodes_per_mesh : int = 1 ) -> None :
50- config = cls .get_config (mesh_name , nodes_per_mesh )
54+ async def get_or_create_job (self , mesh_name : str , nodes_per_mesh : int = 1 ) -> None :
55+ config = self .get_config (mesh_name , nodes_per_mesh )
5156 job_name = f"{ MonarchSlurm .job_name_prefix } -{ mesh_name } "
5257 server_spec = await commands .get_or_create (job_name , config , force_restart = True )
53- cls .job_handles [mesh_name ] = server_spec .name
58+ self .job_handles [mesh_name ] = server_spec .name
5459
55- @classmethod
56- def kill_jobs (cls ):
57- for mesh_name , job_handle in cls .job_handles .items ():
58- try :
59- logger .info (f"Destroying job for mesh { mesh_name } " )
60- commands .kill (f"slurm:///{ job_handle } " )
61- except Exception as e :
62- logger .warning (f"Failed to destroy job for { mesh_name } : { e } " )
60+ def kill_jobs (self ):
61+ for mesh_name in self .job_handles .keys ():
62+ self .kill_job (mesh_name )
63+
64+ def kill_job (self , mesh_name : str ):
65+ try :
66+ job_handle = self .job_handles [mesh_name ]
67+ logger .info (f"Destroying job for mesh { mesh_name } " )
68+ commands .kill (f"slurm:///{ job_handle } " )
69+ except Exception as e :
70+ logger .warning (f"Failed to destroy job for { mesh_name } : { e } " )
6371
64- @classmethod
6572 def proc_mesh (
66- cls ,
73+ self ,
6774 mesh_name : str ,
6875 num_hosts : int = 1 ,
6976 num_gpus : int = 8 ,
7077 ) -> ProcMesh :
7178 allocator = RemoteAllocator (
7279 world_id = MonarchSlurm .job_name_prefix ,
7380 initializer = TorchXRemoteAllocInitializer (
74- f"slurm:///{ cls .job_handles [mesh_name ]} "
81+ f"slurm:///{ self .job_handles [mesh_name ]} "
7582 ),
7683 )
7784 alloc = allocator .allocate (
@@ -140,6 +147,7 @@ class JobSpec:
140147 replica_count : int
141148 hosts_per_replica : int
142149 gpus_per_node : int
150+ with_failures : bool
143151 lighthouse_address : str = ""
144152
145153
@@ -154,16 +162,15 @@ class Replica:
154162# This does not currently benefit from being an actor, but will once
155163# Monarch supervision APIs are fleshed out.
156164class ReplicaActor (Actor ):
157- def __init__ (
158- self ,
159- spec : JobSpec ,
160- replica_id : int ,
161- ) -> None :
165+ def __init__ (self , spec : JobSpec , replica_id : int , scheduler : MonarchSlurm ) -> None :
162166 self .spec = deepcopy (spec )
163167 self .replica_id = replica_id
164168
165169 self .uid = f"[replica_{ replica_id } ]"
166170 self .spec .job_config .fault_tolerance .replica_id = self .replica_id
171+ self .scheduler = scheduler
172+
173+ self .failure_actors : FailureActor | None = None
167174
168175 @endpoint
169176 async def start_replica (self ) -> None :
@@ -172,7 +179,7 @@ async def start_replica(self) -> None:
172179
173180 trainers_proc_mesh : ProcMesh | None = None
174181 try :
175- trainers_proc_mesh = MonarchSlurm .proc_mesh (
182+ trainers_proc_mesh = self . scheduler .proc_mesh (
176183 f"replica_{ self .replica_id } " ,
177184 self .spec .hosts_per_replica ,
178185 self .spec .gpus_per_node ,
@@ -189,6 +196,10 @@ async def start_replica(self) -> None:
189196 self .replica_id ,
190197 )
191198
199+ self .failure_actors = trainers_proc_mesh .spawn (
200+ "failure_actors" , FailureActor
201+ )
202+
192203 logger .info (f"{ self .uid } Starting trainers" )
193204 await training_actors .start_training .call (self .spec .lighthouse_address )
194205 await trainers_proc_mesh .stop ()
@@ -197,13 +208,29 @@ async def start_replica(self) -> None:
197208 await trainers_proc_mesh .stop ()
198209 raise e
199210
211+ @endpoint
212+ async def inject_failure (self , failure_type : Failure ):
213+ if self .failure_actors :
214+ try :
215+ logger .info (
216+ f"{ self .uid } Injecting failure ({ failure_type } ) into random trainer"
217+ )
218+
219+ await self .failure_actors .fail .choose (failure_type )
220+ except Exception as e :
221+ error_msg = f"{ self .uid } Injected failure: { e } "
222+ logger .error (error_msg )
223+ else :
224+ error_msg = f"{ self .uid } No failure actors available"
225+ logger .error (error_msg )
226+
200227
# Delay (seconds) before re-creating a proc mesh on an existing job. Change as needed.
PROC_ATTEMPT_DELAY = 0
# Proc-mesh attempts before requesting a fresh scheduler allocation. Change as needed.
PROC_ATTEMPTS = 3
# Total attempts before training on a replica is declared failed. Change as needed.
MAX_ATTEMPT = PROC_ATTEMPTS * 3
207234
208235
209236class OrchestrationManager :
@@ -213,26 +240,37 @@ def __init__(self, spec: JobSpec) -> None:
213240 self .lighthouse_actor : LighthouseActor | None = None
214241 self .lighthouse_mesh : ProcMesh | None = None
215242
243+ self .scheduler = MonarchSlurm ()
244+
216245 async def start_training (self ) -> None :
217246 logger .info (
218247 f"[Controller] Creating training system with { self .spec .replica_count } replicas"
219248 )
220249
221250 for replica_id in range (self .spec .replica_count ):
222- await MonarchSlurm .get_or_create_job (
251+ await self . scheduler .get_or_create_job (
223252 f"replica_{ replica_id } " , self .spec .hosts_per_replica
224253 )
225254
226255 mesh_futures = {}
227256 for i in range (self .spec .replica_count ):
228257 mesh_futures [i ] = asyncio .create_task (self ._run_replica (i , 0 ))
229258
259+ failure_future = None
260+ if self .spec .with_failures :
261+ failure_future = asyncio .create_task (
262+ FailureController .execute_failures (self .replicas , self .scheduler )
263+ )
264+
230265 await asyncio .gather (* mesh_futures .values (), return_exceptions = True )
231266
267+ if failure_future :
268+ failure_future .cancel ()
269+
232270 async def start_lighthouse (self ) -> None :
233271 if self .spec .remote_lighthouse :
234- await MonarchSlurm .get_or_create_job ("lighthouse" )
235- self .lighthouse_mesh = MonarchSlurm .proc_mesh ("lighthouse" , num_gpus = 1 )
272+ await self . scheduler .get_or_create_job ("lighthouse" )
273+ self .lighthouse_mesh = self . scheduler .proc_mesh ("lighthouse" , num_gpus = 1 )
236274 else :
237275 self .lighthouse_mesh = this_host ().spawn_procs ({"gpus" : 1 })
238276
@@ -274,7 +312,8 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> No
274312 logger .info (
275313 f"[Controller] Replica { replica_id } has failed { attempt_number } times. Getting new allocation."
276314 )
277- await MonarchSlurm .get_or_create_job (
315+ self .scheduler .kill_job (f"replica_{ replica_id } " )
316+ await self .scheduler .get_or_create_job (
278317 f"replica_{ replica_id } " , self .spec .hosts_per_replica
279318 )
280319 delay = 0 if not attempt_number else PROC_ATTEMPT_DELAY
@@ -287,10 +326,7 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> No
287326 await replica_proc_mesh .logging_option (aggregate_window_sec = None )
288327
289328 replica_actor = replica_proc_mesh .spawn (
290- "replica_actor" ,
291- ReplicaActor ,
292- self .spec ,
293- replica_id ,
329+ "replica_actor" , ReplicaActor , self .spec , replica_id , self .scheduler
294330 )
295331
296332 replica = Replica (replica_id , replica_proc_mesh , replica_actor , attempt_number )
@@ -339,20 +375,25 @@ def parse_args() -> argparse.Namespace:
339375 parser .add_argument (
340376 "--model-config" ,
341377 type = str ,
342- default = os . path . join ( script_dir , "debug_model.toml" ) ,
343- help = f"Path to model configuration file (default: { os .path .join (script_dir , 'debug_model.toml' )} )" ,
378+ default = "debug_model.toml" ,
379+ help = f"Relative path to model configuration file (default: { os .path .join (script_dir , 'debug_model.toml' )} )" ,
344380 )
345381 parser .add_argument (
346382 "--dataset-path" ,
347383 type = str ,
348- default = os . path . join ( script_dir , "c4_test" ) ,
349- help = f"Path to training dataset (default: { os .path .join (script_dir , 'c4_test' )} )" ,
384+ default = "c4_test" ,
385+ help = f"Relative path to training dataset (default: { os .path .join (script_dir , 'c4_test' )} )" ,
350386 )
351387 parser .add_argument (
352388 "--tokenizer-path" ,
353389 type = str ,
354- default = os .path .join (script_dir , "tokenizer" ),
355- help = f"Path to tokenizer (default: { os .path .join (script_dir , 'tokenizer' )} )" ,
390+ default = "debug_tokenizer" ,
391+ help = f"Relative path to tokenizer (default: { os .path .join (script_dir , 'debug_tokenizer' )} )" ,
392+ )
393+ parser .add_argument (
394+ "--with-failures" ,
395+ action = "store_true" ,
396+ help = "Enable the failure injector utility (default: False)" ,
356397 )
357398
358399 return parser .parse_args ()
@@ -362,13 +403,14 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
362403 data_parallel_shard_degree = args .gpu_per_node * args .host_per_replica
363404
364405 output_path = "./outputs"
365- training_dataset = "c4_test"
406+ training_dataset = args . dataset_path . split ( "/" )[ - 1 ]
366407
408+ script_dir = os .path .dirname (os .path .abspath (__file__ ))
367409 default_args = [
368410 "--job.config_file" ,
369- args .model_config ,
411+ os . path . join ( script_dir , args .model_config ) ,
370412 "--model.tokenizer_path" ,
371- args .tokenizer_path ,
413+ os . path . join ( script_dir , args .tokenizer_path ) ,
372414 "--comm.trace_buf_size" ,
373415 "0" ,
374416 "--metrics.log_freq" ,
@@ -387,7 +429,7 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
387429 "--training.dataset" ,
388430 training_dataset ,
389431 "--training.dataset_path" ,
390- args .dataset_path ,
432+ os . path . join ( script_dir , args .dataset_path ) ,
391433 "--job.dump_folder" ,
392434 output_path ,
393435 "--metrics.enable_tensorboard" ,
@@ -402,6 +444,7 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
402444 replica_count = args .replica_count ,
403445 hosts_per_replica = args .host_per_replica ,
404446 gpus_per_node = args .gpu_per_node ,
447+ with_failures = args .with_failures ,
405448 )
406449
407450
@@ -414,7 +457,6 @@ async def main() -> None:
414457 args = parse_args ()
415458 job_spec = make_job_spec (args )
416459
417- atexit .register (MonarchSlurm .kill_jobs )
418460 orchestrator = OrchestrationManager (job_spec )
419461 try :
420462 await orchestrator .start_lighthouse ()
0 commit comments