Commit 0b95922

refactor: rename and isolate ckpt args
1 parent 1b3ea8a commit 0b95922

2 files changed: +27 -27 lines changed

README.md

Lines changed: 10 additions & 10 deletions
@@ -120,8 +120,8 @@ torchrun --nproc_per_node=8 \
     --lr 4e-4 \
     --path_model PrimeIntellect/llama-150m-fresh \
     --log-activations-steps 200 \
-    --checkpoint-interval 8000 \
-    --checkpoint-path 150_ckpt
+    --ckpt.interval 8000 \
+    --ckpt.path 150_ckpt
 ```
 
 ## 150m on 8 DiLoCo Worker with 500 local steps
@@ -139,8 +139,8 @@ In the `open_diloco` folder, run:
     --lr 4e-4 \
     --path-model PrimeIntellect/llama-150m-fresh \
     --log-activations-steps 250 \
-    --checkpoint-interval 4975 \
-    --checkpoint-path 150_ckpt
+    --ckpt.interval 4975 \
+    --ckpt.path 150_ckpt
 ```
 
 under the hood the `run_training.sh` script calls `train_fsdp.py` 8 times with the right argument to simulate 8 workers locally.
@@ -161,8 +161,8 @@ In the `open_diloco` folder, run:
     --lr 4e-4 \
     --path-model PrimeIntellect/llama-150m-fresh \
     --log-activations-steps 250 \
-    --checkpoint-interval 4975 \
-    --checkpoint-path 150_ckpt
+    --ckpt.interval 4975 \
+    --ckpt.path 150_ckpt
 ```
 
 ## 1b Baseline
@@ -178,8 +178,8 @@ torchrun --nproc_per_node=8 \
     --project OpenDiLoCo \
     --lr 4e-4 \
     --path_model PrimeIntellect/llama-1b-fresh \
-    --checkpoint-path 1b_ckpt \
-    --checkpoint-interval 500
+    --ckpt.path 1b_ckpt \
+    --ckpt.interval 500
 ```
 
 ## 1b on 4 DiLoCo Workers with 500 local steps
@@ -208,7 +208,7 @@ torchrun --nproc_per_node=8 \
     --hv.galaxy-size 4 \
     --hv.world-rank $WORLD_RANK \
     --checkpoint_interval 500 \
-    --checkpoint-path 1b_diloco_ckpt
+    --ckpt.path 1b_diloco_ckpt
 ```
 ## 1b on 4 DiLoCo Workers with 125 local steps
 
@@ -238,7 +238,7 @@ torchrun --nproc_per_node=8 \
     --hv.galaxy-size 4 \
     --hv.world-rank $WORLD_RANK \
     --checkpoint_interval 500 \
-    --checkpoint-path 1b_diloco_ckpt
+    --ckpt.path 1b_diloco_ckpt
 ```
 
 # Use OpenDiLoCo in your own code
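
The flag rename above is the user-visible half of this commit: `--checkpoint-interval` and `--checkpoint-path` become `--ckpt.interval` and `--ckpt.path`, mirroring the nested `ckpt` config added to `train_fsdp.py` below. The commit does not show how the CLI is parsed, so the following is only a minimal sketch of how dotted flags can be mapped onto a nested pydantic-style config; the `parse_dotted_args` helper and the stand-in `CkptConfig`/`Config` classes are illustrative, not the project's actual parser.

```python
# Minimal sketch (not the project's parser): map dotted CLI flags such as
# `--ckpt.interval 8000 --ckpt.path 150_ckpt` onto a nested pydantic config.
from pydantic import BaseModel


class CkptConfig(BaseModel):
    resume: str | None = None
    interval: int | None = None
    path: str = "outputs"


class Config(BaseModel):
    path_model: str = "PrimeIntellect/llama-150m-fresh"
    ckpt: CkptConfig = CkptConfig()


def parse_dotted_args(argv: list[str]) -> dict:
    """Turn ["--ckpt.interval", "8000", ...] into {"ckpt": {"interval": "8000"}}."""
    nested: dict = {}
    for flag, value in zip(argv[::2], argv[1::2]):
        parts = flag.removeprefix("--").split(".")
        node = nested
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return nested


if __name__ == "__main__":
    args = parse_dotted_args(["--ckpt.interval", "8000", "--ckpt.path", "150_ckpt"])
    config = Config(**args)  # pydantic validates the nested dict and coerces "8000" -> 8000
    print(config.ckpt.interval, config.ckpt.path)  # 8000 150_ckpt
```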

open_diloco/train_fsdp.py

Lines changed: 17 additions & 17 deletions
@@ -113,6 +113,12 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
         return values
 
 
+class CkptConfig(BaseConfig):
+    resume: str | None = None
+    interval: int | None = None
+    path: str = "outputs"
+
+
 class Config(BaseConfig):
     path_model: str = "PrimeIntellect/llama-150m-fresh"
     torch_compile: bool = True
@@ -133,9 +139,7 @@ class Config(BaseConfig):
     # Checkpointing and logging
     project: str = "hivemind_debug"
     log_activations_steps: int | None = None
-    resume_from_checkpoint: str | None = None
-    checkpoint_interval: int | None = None
-    checkpoint_path: str = "outputs"
+    ckpt: CkptConfig = CkptConfig()
     # Hivemind
     hv: HvConfig | None = None  # if no hv config then hivemind is disabled
     fake_data: bool = False
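
Collapsing the three flat fields into a single nested `ckpt` field means call sites change from `config.checkpoint_path` to `config.ckpt.path`, and all checkpoint-related defaults now live in `CkptConfig`. A small illustration of the resulting behavior, assuming `BaseConfig` behaves like a pydantic `BaseModel` (the base class is not shown in this diff; the classes below are simplified stand-ins):

```python
from pydantic import BaseModel


class CkptConfig(BaseModel):
    resume: str | None = None
    interval: int | None = None
    path: str = "outputs"


class Config(BaseModel):
    ckpt: CkptConfig = CkptConfig()


# A partial override via a plain dict is validated into a CkptConfig;
# unspecified fields keep the defaults declared on CkptConfig.
config = Config(ckpt={"interval": 500})
assert config.ckpt.interval == 500
assert config.ckpt.path == "outputs"
assert config.ckpt.resume is None

# The CkptConfig() default is copied per instance, not shared between Configs.
assert Config().ckpt is not Config().ckpt
```
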
@@ -228,7 +232,7 @@ def train(config: Config):
         log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=False)
 
     if local_rank == 0:
-        check_checkpoint_path_access(config.checkpoint_path, rank, config.hv.world_rank if config.hv else None)
+        check_checkpoint_path_access(config.ckpt.path, rank, config.hv.world_rank if config.hv else None)
 
     # DataLoader preparation
     tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
@@ -279,16 +283,14 @@ def scheduler_fn(opt):
     )
 
     if config.hv is not None:
-        if config.resume_from_checkpoint:
+        if config.ckpt.resume:
             # We need to load with a fake optimizer to set the model parameters correctly before initializing the DiLoCoOptimizer
             # This is because the DiLoCoOptimizer makes a copy of the model parameters for the state averager which is hard to update later
             # We also need to do this on follower workers so that the world_messenger has friends to talk to when it does its two loads
             # Otherwise the world messenger will get lonely and hang
             fake_optimizer = inner_optimizer(model.parameters())
             last_loss = load_checkpoint(
-                checkpoint_path=os.path.join(
-                    config.resume_from_checkpoint, get_diloco_rank_dir_name(config.hv.world_rank)
-                ),
+                checkpoint_path=os.path.join(config.ckpt.resume, get_diloco_rank_dir_name(config.hv.world_rank)),
                 model=model,
                 optimizer=fake_optimizer,
             )
@@ -325,11 +327,9 @@ def scheduler_fn(opt):
             optimizer.inner_optimizer
         )  # scheduler(optimizer) should work but better to make it explicit here
 
-        if config.resume_from_checkpoint:
+        if config.ckpt.resume:
             last_loss = load_checkpoint(
-                checkpoint_path=os.path.join(
-                    config.resume_from_checkpoint, get_diloco_rank_dir_name(config.hv.world_rank)
-                ),
+                checkpoint_path=os.path.join(config.ckpt.resume, get_diloco_rank_dir_name(config.hv.world_rank)),
                 model=model,
                 optimizer=optimizer.inner_optimizer,
                 scheduler=scheduler,
@@ -344,9 +344,9 @@ def scheduler_fn(opt):
     else:
         optimizer = inner_optimizer(model.parameters())
         scheduler = scheduler_fn(optimizer)
-        if config.resume_from_checkpoint:
+        if config.ckpt.resume:
             last_loss = load_checkpoint(
-                checkpoint_path=config.resume_from_checkpoint,
+                checkpoint_path=config.ckpt.resume,
                 model=model,
                 optimizer=optimizer,
                 scheduler=scheduler,
@@ -357,7 +357,7 @@ def scheduler_fn(opt):
         else:
             start_step = 0
 
-    if config.resume_from_checkpoint:
+    if config.ckpt.resume:
         log(f"Resumed from checkpoint at step {start_step} with loss {last_loss}")
 
     model.train()
@@ -481,9 +481,9 @@ def scheduler_fn(opt):
             )
 
             # Save checkpoint every 'checkpoint_interval' steps
-            if config.checkpoint_interval is not None and real_step % config.checkpoint_interval == 0:
+            if config.ckpt.interval is not None and real_step % config.ckpt.interval == 0:
                 log(f"saving at step {real_step}, step {step+1}")
-                ckpt_path = os.path.join(config.checkpoint_path, f"model_step_{int(real_step)}")
+                ckpt_path = os.path.join(config.ckpt.path, f"model_step_{int(real_step)}")
 
                 if config.hv:
                     ckpt_path = os.path.join(ckpt_path, get_diloco_rank_dir_name(config.hv.world_rank))
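
Reading this save site together with the resume sites earlier in the diff: checkpoints are written under `config.ckpt.path/model_step_<real_step>` (plus a per-worker subdirectory when hivemind is enabled), and resumed from `config.ckpt.resume` with the same per-worker subdirectory appended. `get_diloco_rank_dir_name` is not shown in this commit, so the sketch below substitutes a hypothetical `diloco_rank_<world_rank>` layout purely to make the path composition concrete.

```python
import os


def get_diloco_rank_dir_name(world_rank: int) -> str:
    # Hypothetical stand-in: the real helper lives in the project's checkpoint
    # utilities and its exact directory-name format is not shown in this diff.
    return f"diloco_rank_{world_rank}"


def save_dir(ckpt_path: str, real_step: int, world_rank: int | None) -> str:
    """Mirror the save-site logic: <ckpt.path>/model_step_<step>[/<rank dir>]."""
    path = os.path.join(ckpt_path, f"model_step_{int(real_step)}")
    if world_rank is not None:  # config.hv set -> one subdirectory per DiLoCo worker
        path = os.path.join(path, get_diloco_rank_dir_name(world_rank))
    return path


def resume_dir(ckpt_resume: str, world_rank: int | None) -> str:
    """Mirror the resume-site logic: <ckpt.resume>[/<rank dir>]."""
    if world_rank is None:
        return ckpt_resume
    return os.path.join(ckpt_resume, get_diloco_rank_dir_name(world_rank))


print(save_dir("150_ckpt", 8000, None))    # 150_ckpt/model_step_8000
print(save_dir("1b_diloco_ckpt", 500, 3))  # 1b_diloco_ckpt/model_step_500/diloco_rank_3
print(resume_dir("1b_diloco_ckpt", 3))     # 1b_diloco_ckpt/diloco_rank_3
```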
