58
58
# NCCL collective-op timeout in minutes. Environment variables are always
# strings, so cast explicitly — otherwise this is a str when the variable is
# exported but an int (120) when it is not.
TIMEOUT_NCCL_MINUTES = int(os.environ.get("TIMEOUT_NCCL_MINUTES", 120))

# NOTE(review): presumably layer-name substrings used to select which module
# activations are captured — confirm against where this list is consumed.
TARGET_LAYER_ACTIVATIONS = ["self_attn", "lm_head"]

# Reduced vocabulary size used for tests.
TEST_VOCAB_SIZE = 1024

# Common prefix of checkpoint directory names, e.g. "model_step_100000".
CKPT_PREFIX = "model_step"
61
62
62
63
63
64
# Function to initialize the distributed process group
@@ -114,10 +115,35 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
114
115
115
116
116
117
class CkptConfig(BaseConfig):
    """Checkpointing configuration.

    ``resume`` may be:
      * ``None``  — no resume (``get_resume_path`` raises),
      * a ``str`` — explicit path of the checkpoint to resume from,
      * a ``bool`` — auto-discover the latest checkpoint under ``path``.
    """

    # If resume is a boolean, it means we should resume from the last checkpoint.
    resume: str | bool | None = None
    # Save a checkpoint every `interval` steps; None disables periodic saving.
    interval: int | None = None
    # Directory where checkpoints are written (and searched on auto-resume).
    path: str = "outputs"

    def get_resume_path(self):
        """Return the checkpoint path to resume from.

        Raises:
            ValueError: if ``resume`` is unset, or if auto-discovery finds no
                checkpoint matching ``CKPT_PREFIX`` under ``path``.
        """
        if self.resume is None:
            raise ValueError("Resume path is not set")
        elif isinstance(self.resume, bool):
            # Auto-discover: list the checkpoint directory via fsspec and pick
            # the entry with the highest trailing step number.
            fs = GenericFileSystem()

            def filter_ckpt_files(f):
                # Entries usually look like "file:///hello/model_step_100000";
                # keep only those whose suffix after the last "_" is an int.
                if CKPT_PREFIX not in f:
                    return False
                try:
                    int(f.split("_")[-1])
                    return True
                except ValueError:
                    return False

            ckpt_files = [f for f in fs.ls(self.path, detail=False) if filter_ckpt_files(f)]
            if not ckpt_files:
                # Without this guard, max() on an empty list raises the opaque
                # "max() arg is an empty sequence"; fail with a clear message
                # (still ValueError, so existing handlers keep working).
                raise ValueError(
                    f"No checkpoint matching '{CKPT_PREFIX}_<step>' found in {self.path}"
                )
            latest_ckpt = max(ckpt_files, key=lambda f: int(f.split("_")[-1]))
            return latest_ckpt

        return self.resume
121
147
122
148
class Config (BaseConfig ):
123
149
path_model : str = "PrimeIntellect/llama-150m-fresh"
@@ -290,7 +316,9 @@ def scheduler_fn(opt):
290
316
# Otherwise the world messenger will get lonely and hang
291
317
fake_optimizer = inner_optimizer (model .parameters ())
292
318
last_loss = load_checkpoint (
293
- checkpoint_path = os .path .join (config .ckpt .resume , get_diloco_rank_dir_name (config .hv .world_rank )),
319
+ checkpoint_path = os .path .join (
320
+ config .ckpt .get_resume_path (), get_diloco_rank_dir_name (config .hv .world_rank )
321
+ ),
294
322
model = model ,
295
323
optimizer = fake_optimizer ,
296
324
)
@@ -329,7 +357,9 @@ def scheduler_fn(opt):
329
357
330
358
if config .ckpt .resume :
331
359
last_loss = load_checkpoint (
332
- checkpoint_path = os .path .join (config .ckpt .resume , get_diloco_rank_dir_name (config .hv .world_rank )),
360
+ checkpoint_path = os .path .join (
361
+ config .ckpt .get_resume_path (), get_diloco_rank_dir_name (config .hv .world_rank )
362
+ ),
333
363
model = model ,
334
364
optimizer = optimizer .inner_optimizer ,
335
365
scheduler = scheduler ,
@@ -346,7 +376,7 @@ def scheduler_fn(opt):
346
376
scheduler = scheduler_fn (optimizer )
347
377
if config .ckpt .resume :
348
378
last_loss = load_checkpoint (
349
- checkpoint_path = config .ckpt .resume ,
379
+ checkpoint_path = config .ckpt .get_resume_path () ,
350
380
model = model ,
351
381
optimizer = optimizer ,
352
382
scheduler = scheduler ,
@@ -483,7 +513,7 @@ def scheduler_fn(opt):
483
513
# Save checkpoint every 'checkpoint_interval' steps
484
514
if config .ckpt .interval is not None and real_step % config .ckpt .interval == 0 :
485
515
log (f"saving at step { real_step } , step { step + 1 } " )
486
- ckpt_path = os .path .join (config .ckpt .path , f"model_step_ { int (real_step )} " )
516
+ ckpt_path = os .path .join (config .ckpt .path , f"{ CKPT_PREFIX } _ { int (real_step )} " )
487
517
488
518
if config .hv :
489
519
ckpt_path = os .path .join (ckpt_path , get_diloco_rank_dir_name (config .hv .world_rank ))
0 commit comments