Skip to content

Commit 9cd4a43

Browse files
committed
feat: allow to keep only k ckpt
1 parent 69fdb96 commit 9cd4a43

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

open_diloco/train_fsdp.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,15 @@ def get_diloco_rank_dir_name(world_rank_diloco: int) -> str:
8989
return f"diloco_rank_{world_rank_diloco}"
9090

9191

92+
def delete_old_checkpoints(checkpoint_path: str, topk: int):
    """Delete all but the ``topk`` highest-numbered checkpoints in ``checkpoint_path``.

    Entries listed under ``checkpoint_path`` are considered checkpoints only when
    ``filter_ckpt_files`` accepts them. They are ordered by the integer that
    follows the last underscore of their path, and every entry except the
    ``topk`` largest ones is removed recursively.

    Args:
        checkpoint_path: Location to list for checkpoints.
        topk: Number of checkpoints to keep; must be at least 1.

    Raises:
        ValueError: If ``topk`` is smaller than 1.
    """
    # Validate before touching the filesystem: with the slice ``[:-topk]`` a
    # value of 0 would silently delete nothing, and a negative value would
    # delete the ``abs(topk)`` oldest checkpoints instead of keeping ``topk``.
    if topk < 1:
        raise ValueError(f"topk must be a positive integer, got {topk}")

    fs = GenericFileSystem()
    ckpt_files = [f for f in fs.ls(checkpoint_path, detail=False) if filter_ckpt_files(f)]
    # Oldest first, so that everything before the last ``topk`` entries is stale.
    ckpt_files.sort(key=lambda x: int(x.split("_")[-1]))
    for ckpt_file in ckpt_files[:-topk]:
        log(f"Deleting old checkpoint {ckpt_file}")
        fs.rm(ckpt_file, recursive=True)
99+
100+
92101
class HvConfig(BaseConfig):
93102
outer_lr: float = 0.7
94103
local_steps: int = 500
@@ -114,31 +123,34 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
114123
return values
115124

116125

126+
def filter_ckpt_files(f):
    """Tell whether the path ``f`` names a checkpoint.

    A path qualifies when it contains ``CKPT_PREFIX`` and the text after its
    last underscore parses as an integer.
    """
    if CKPT_PREFIX in f:
        # Only the piece after the final underscore matters for the check.
        step_suffix = f.rsplit("_", 1)[-1]
        try:
            int(step_suffix)
        except ValueError:
            return False
        return True
    return False
135+
136+
117137
class CkptConfig(BaseConfig):
118138
resume: str | bool | None = None # if resume is a boolean, it means we should resume from the last checkpoint
119139
interval: int | None = None
120140
path: str = "outputs"
141+
topk: int | None = None # how many checkpoints to keep
121142

122143
def get_resume_path(self):
123144
if self.resume is None:
124145
raise ValueError("Resume path is not set")
125146
elif isinstance(self.resume, bool):
126147
# Using fsspec to list directory contents
127148
fs = GenericFileSystem()
149+
ckpt_files = [f for f in fs.ls(self.path, detail=False) if filter_ckpt_files(f)]
128150

129-
def filter_ckpt_files(f):
130-
if CKPT_PREFIX not in f:
131-
return False
132-
else:
133-
try:
134-
int(f.split("_")[-1])
135-
return True
136-
except ValueError:
137-
return False
151+
if len(ckpt_files) == 0:
152+
raise ValueError(f"No checkpoints found in {self.path}")
138153

139-
ckpt_files = [f for f in fs.ls(self.path, detail=False) if filter_ckpt_files(f)]
140-
# Regex to extract numbers following the CKPT_PREFIX and an underscore
141-
# f is usually something like this "file:///hello/model_step_100000"
142154
latest_ckpt = max(ckpt_files, key=lambda f: int(f.split("_")[-1]))
143155
return latest_ckpt
144156

@@ -544,6 +556,11 @@ def scheduler_fn(opt):
544556
save_global_state=rank == 0,
545557
)
546558

559+
if local_rank == 0:
560+
# only the rank 0 deletes the checkpoints
561+
if config.ckpt.topk is not None:
562+
delete_old_checkpoints(config.ckpt.path, config.ckpt.topk)
563+
547564
loss_batch = 0
548565

549566
if config.max_steps is not None and real_step >= config.max_steps:

0 commit comments

Comments
 (0)