Skip to content

Commit 69fdb96

Browse files
committed
feat: auto resume from latest ckpt
1 parent 0b95922 commit 69fdb96

File tree

1 file changed

+35
-5
lines changed

1 file changed

+35
-5
lines changed

open_diloco/train_fsdp.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
# Timeout expressed in minutes (presumably for NCCL collectives -- confirm at
# the usage site). Environment variables are always strings, so cast to int to
# keep the constant the same type whether or not the override is set.
TIMEOUT_NCCL_MINUTES = int(os.environ.get("TIMEOUT_NCCL_MINUTES", 120))
TARGET_LAYER_ACTIVATIONS = ["self_attn", "lm_head"]
TEST_VOCAB_SIZE = 1024
# Checkpoint directories are named f"{CKPT_PREFIX}_{step}".
CKPT_PREFIX = "model_step"
6162

6263

6364
# Function to initialize the distributed process group
@@ -114,10 +115,35 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
114115

115116

116117
class CkptConfig(BaseConfig):
    """Checkpoint settings: what to resume from, how often to save, and where."""

    # A string is used as the checkpoint path to resume from; a boolean means
    # "resume from the latest checkpoint found under `path`"; None means the
    # resume path is not set.
    resume: str | bool | None = None
    # Checkpoint interval in steps; None disables periodic saving.
    interval: int | None = None
    # Directory that holds the f"{CKPT_PREFIX}_{step}" checkpoint entries.
    path: str = "outputs"

    def get_resume_path(self):
        """Return the checkpoint path to resume from.

        When `resume` is a string it is returned unchanged. When `resume` is a
        boolean, `path` is listed and the entry named f"{CKPT_PREFIX}_{step}"
        with the highest integer step is returned.

        Raises:
            ValueError: if `resume` is None, or if `resume` is a boolean and
                no checkpoint entry exists under `path`.
        """
        if self.resume is None:
            raise ValueError("Resume path is not set")
        if not isinstance(self.resume, bool):
            return self.resume

        prefix = f"{CKPT_PREFIX}_"

        def ckpt_step(entry):
            """Return the step encoded in a listing entry, or None if it is not a checkpoint."""
            # entry is usually something like "file:///hello/model_step_100000".
            # Match on the final path component only, so that a parent
            # directory containing the prefix cannot make siblings qualify.
            name = entry.rstrip("/").rsplit("/", 1)[-1]
            if not name.startswith(prefix):
                return None
            try:
                return int(name[len(prefix):])
            except ValueError:
                return None

        # Using fsspec to list directory contents
        fs = GenericFileSystem()
        steps = {}
        for entry in fs.ls(self.path, detail=False):
            step = ckpt_step(entry)
            if step is not None:
                steps[entry] = step

        if not steps:
            # Fail with an actionable message instead of max()'s
            # "arg is an empty sequence".
            raise ValueError(f"No checkpoint named '{prefix}<step>' found in {self.path}")

        return max(steps, key=steps.get)
146+
121147

122148
class Config(BaseConfig):
123149
path_model: str = "PrimeIntellect/llama-150m-fresh"
@@ -290,7 +316,9 @@ def scheduler_fn(opt):
290316
# Otherwise the world messenger will get lonely and hang
291317
fake_optimizer = inner_optimizer(model.parameters())
292318
last_loss = load_checkpoint(
293-
checkpoint_path=os.path.join(config.ckpt.resume, get_diloco_rank_dir_name(config.hv.world_rank)),
319+
checkpoint_path=os.path.join(
320+
config.ckpt.get_resume_path(), get_diloco_rank_dir_name(config.hv.world_rank)
321+
),
294322
model=model,
295323
optimizer=fake_optimizer,
296324
)
@@ -329,7 +357,9 @@ def scheduler_fn(opt):
329357

330358
if config.ckpt.resume:
331359
last_loss = load_checkpoint(
332-
checkpoint_path=os.path.join(config.ckpt.resume, get_diloco_rank_dir_name(config.hv.world_rank)),
360+
checkpoint_path=os.path.join(
361+
config.ckpt.get_resume_path(), get_diloco_rank_dir_name(config.hv.world_rank)
362+
),
333363
model=model,
334364
optimizer=optimizer.inner_optimizer,
335365
scheduler=scheduler,
@@ -346,7 +376,7 @@ def scheduler_fn(opt):
346376
scheduler = scheduler_fn(optimizer)
347377
if config.ckpt.resume:
348378
last_loss = load_checkpoint(
349-
checkpoint_path=config.ckpt.resume,
379+
checkpoint_path=config.ckpt.get_resume_path(),
350380
model=model,
351381
optimizer=optimizer,
352382
scheduler=scheduler,
@@ -483,7 +513,7 @@ def scheduler_fn(opt):
483513
# Save checkpoint every 'checkpoint_interval' steps
484514
if config.ckpt.interval is not None and real_step % config.ckpt.interval == 0:
485515
log(f"saving at step {real_step}, step {step+1}")
486-
ckpt_path = os.path.join(config.ckpt.path, f"model_step_{int(real_step)}")
516+
ckpt_path = os.path.join(config.ckpt.path, f"{CKPT_PREFIX}_{int(real_step)}")
487517

488518
if config.hv:
489519
ckpt_path = os.path.join(ckpt_path, get_diloco_rank_dir_name(config.hv.world_rank))

0 commit comments

Comments
 (0)