@@ -113,6 +113,12 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
         return values
 
 
+class CkptConfig(BaseConfig):
+    resume: str | None = None
+    interval: int | None = None
+    path: str = "outputs"
+
+
 class Config(BaseConfig):
     path_model: str = "PrimeIntellect/llama-150m-fresh"
     torch_compile: bool = True
@@ -133,9 +139,7 @@ class Config(BaseConfig):
     # Checkpointing and logging
     project: str = "hivemind_debug"
     log_activations_steps: int | None = None
-    resume_from_checkpoint: str | None = None
-    checkpoint_interval: int | None = None
-    checkpoint_path: str = "outputs"
+    ckpt: CkptConfig = CkptConfig()
     # Hivemind
     hv: HvConfig | None = None  # if no hv config then hivemind is disabled
     fake_data: bool = False
@@ -228,7 +232,7 @@ def train(config: Config):
     log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=False)
 
     if local_rank == 0:
-        check_checkpoint_path_access(config.checkpoint_path, rank, config.hv.world_rank if config.hv else None)
+        check_checkpoint_path_access(config.ckpt.path, rank, config.hv.world_rank if config.hv else None)
 
     # DataLoader preparation
     tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
@@ -279,16 +283,14 @@ def scheduler_fn(opt):
     )
 
     if config.hv is not None:
-        if config.resume_from_checkpoint:
+        if config.ckpt.resume:
             # We need to load with a fake optimizer to set the model parameters correctly before initializing the DiLoCoOptimizer
             # This is because the DiLoCoOptimizer makes a copy of the model parameters for the state averager which is hard to update later
             # We also need to do this on follower workers so that the world_messenger has friends to talk to when it does its two loads
             # Otherwise the world messenger will get lonely and hang
             fake_optimizer = inner_optimizer(model.parameters())
             last_loss = load_checkpoint(
-                checkpoint_path=os.path.join(
-                    config.resume_from_checkpoint, get_diloco_rank_dir_name(config.hv.world_rank)
-                ),
+                checkpoint_path=os.path.join(config.ckpt.resume, get_diloco_rank_dir_name(config.hv.world_rank)),
                 model=model,
                 optimizer=fake_optimizer,
             )
@@ -325,11 +327,9 @@ def scheduler_fn(opt):
                 optimizer.inner_optimizer
             )  # scheduler(optimizer) should work but better to make it explicit here
 
-            if config.resume_from_checkpoint:
+            if config.ckpt.resume:
                 last_loss = load_checkpoint(
-                    checkpoint_path=os.path.join(
-                        config.resume_from_checkpoint, get_diloco_rank_dir_name(config.hv.world_rank)
-                    ),
+                    checkpoint_path=os.path.join(config.ckpt.resume, get_diloco_rank_dir_name(config.hv.world_rank)),
                     model=model,
                     optimizer=optimizer.inner_optimizer,
                     scheduler=scheduler,
@@ -344,9 +344,9 @@ def scheduler_fn(opt):
     else:
         optimizer = inner_optimizer(model.parameters())
         scheduler = scheduler_fn(optimizer)
-        if config.resume_from_checkpoint:
+        if config.ckpt.resume:
             last_loss = load_checkpoint(
-                checkpoint_path=config.resume_from_checkpoint,
+                checkpoint_path=config.ckpt.resume,
                 model=model,
                 optimizer=optimizer,
                 scheduler=scheduler,
@@ -357,7 +357,7 @@ def scheduler_fn(opt):
     else:
         start_step = 0
 
-    if config.resume_from_checkpoint:
+    if config.ckpt.resume:
         log(f"Resumed from checkpoint at step {start_step} with loss {last_loss}")
 
     model.train()
@@ -481,9 +481,9 @@ def scheduler_fn(opt):
             )
 
             # Save checkpoint every 'checkpoint_interval' steps
-            if config.checkpoint_interval is not None and real_step % config.checkpoint_interval == 0:
+            if config.ckpt.interval is not None and real_step % config.ckpt.interval == 0:
                 log(f"saving at step {real_step}, step {step + 1}")
-                ckpt_path = os.path.join(config.checkpoint_path, f"model_step_{int(real_step)}")
+                ckpt_path = os.path.join(config.ckpt.path, f"model_step_{int(real_step)}")
 
                 if config.hv:
                     ckpt_path = os.path.join(ckpt_path, get_diloco_rank_dir_name(config.hv.world_rank))
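
Note: the snippet below is not part of the diff; it is a minimal usage sketch of how the three old flat checkpoint fields map onto the new nested ckpt sub-config. The values are illustrative, and it assumes Config and CkptConfig from this script are importable and that the remaining Config fields keep their defaults.

# Before this change the settings were flat on Config:
#   config.resume_from_checkpoint, config.checkpoint_interval, config.checkpoint_path
# After this change the same settings live on a nested CkptConfig (pydantic-style sub-model):
config = Config(ckpt=CkptConfig(resume="outputs/model_step_100", interval=500, path="outputs"))

assert config.ckpt.resume == "outputs/model_step_100"  # was config.resume_from_checkpoint
assert config.ckpt.interval == 500                     # was config.checkpoint_interval
assert config.ckpt.path == "outputs"                   # was config.checkpoint_path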