
Commit 8b7b6a8

refactor: move code to ckpt_utils.py
1 parent 9cd4a43 commit 8b7b6a8

File tree

  open_diloco/ckpt_utils.py
  open_diloco/train_fsdp.py

2 files changed (+80, −66 lines)

open_diloco/ckpt_utils.py

Lines changed: 68 additions & 0 deletions
@@ -1,11 +1,38 @@
 import fsspec
+from pydantic_config import BaseConfig
 import torch
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 import torch.distributed.checkpoint as dcp
 import os
 from torchdata.stateful_dataloader import StatefulDataLoader
+from fsspec.generic import GenericFileSystem
+
 
 GLOBAL_STATE_FILE = "global_state_dict.pt"
+CKPT_PREFIX = "model_step"
+
+
+class CkptConfig(BaseConfig):
+    resume: str | bool | None = None  # if resume is a boolean, it means we should resume from the last checkpoint
+    interval: int | None = None
+    path: str = "outputs"
+    topk: int | None = None  # how many checkpoints to keep
+
+    def get_resume_path(self):
+        if self.resume is None:
+            raise ValueError("Resume path is not set")
+        elif isinstance(self.resume, bool):
+            # Using fsspec to list directory contents
+            fs = GenericFileSystem()
+            ckpt_files = [f for f in fs.ls(self.path, detail=False) if filter_ckpt_files(f)]
+
+            if len(ckpt_files) == 0:
+                raise ValueError(f"No checkpoints found in {self.path}")
+
+            latest_ckpt = max(ckpt_files, key=lambda f: int(f.split("_")[-1]))
+            return latest_ckpt
+
+        return self.resume
 
 
 def save_checkpoint(
@@ -117,3 +144,44 @@ def load_checkpoint(
     if scaler is not None:
         scaler.load_state_dict(global_state_dict["scaler"])
     return global_state_dict["loss"]
+
+
+def filter_ckpt_files(f):
+    if CKPT_PREFIX not in f:
+        return False
+    else:
+        try:
+            int(f.split("_")[-1])
+            return True
+        except ValueError:
+            return False
+
+
+def delete_old_checkpoints(checkpoint_path: str, topk: int) -> list[str]:
+    fs = GenericFileSystem()
+    ckpt_files = [f for f in fs.ls(checkpoint_path, detail=False) if filter_ckpt_files(f)]
+    ckpt_files.sort(key=lambda x: int(x.split("_")[-1]))
+
+    ckpt_deleted = []
+    for ckpt_file in ckpt_files[:-topk]:
+        fs.rm(ckpt_file, recursive=True)
+        ckpt_deleted.append(ckpt_file)
+    return ckpt_deleted
+
+
+def check_checkpoint_path_access(checkpoint_path: str, rank: int, world_rank_hv: int | None = None):
+    if world_rank_hv:
+        dummy_file_path = os.path.join(
+            checkpoint_path, get_diloco_rank_dir_name(world_rank_hv), f"dummy_file_{rank}.txt"
+        )
+    else:
+        dummy_file_path = os.path.join(checkpoint_path, f"dummy_file_{rank}.txt")
+
+    with fsspec.open(dummy_file_path, "w") as f:
+        f.write("This is a dummy file for testing access.")
+    gfs = GenericFileSystem()
+    gfs.rm(dummy_file_path)
+
+
+def get_diloco_rank_dir_name(world_rank_diloco: int) -> str:
+    return f"diloco_rank_{world_rank_diloco}"
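For context on how the new helpers fit together: checkpoints are expected to live under the configured path as entries named with the `model_step` prefix plus an integer step (the `CKPT_PREFIX` convention), and both resuming and retention key off that step suffix. A minimal usage sketch, assuming a hypothetical `outputs/` directory that already holds `model_step_500`, `model_step_1000`, and `model_step_1500`:

    # Sketch only: the outputs/ contents and step numbers here are made up.
    from open_diloco.ckpt_utils import CkptConfig, delete_old_checkpoints

    # resume=True means "pick the checkpoint with the highest step under `path`";
    # a string value would instead be used as an explicit checkpoint path.
    config = CkptConfig(resume=True, interval=500, path="outputs", topk=3)

    latest = config.get_resume_path()  # -> the entry ending in model_step_1500
    print(f"resuming from {latest}")

    # Keep only the `topk` newest checkpoints; the return value lists what
    # was removed (empty here, since only three checkpoints exist).
    deleted = delete_old_checkpoints(config.path, topk=config.topk)

Note that `delete_old_checkpoints` now returns the deleted paths rather than logging inside the helper, which is what lets `train_fsdp.py` keep its own `log` call, as the second file's diff shows.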

open_diloco/train_fsdp.py

Lines changed: 12 additions & 66 deletions
@@ -13,14 +13,12 @@
 import datetime
 from typing import Any, Literal
 
-import fsspec
 from pydantic import model_validator
 import torch
 import wandb
 from pydantic_config import parse_argv, BaseConfig
 from datasets import load_dataset
 from datasets.distributed import split_dataset_by_node
-from fsspec.generic import GenericFileSystem
 from torch.distributed import destroy_process_group, init_process_group
 
 from torchdata.stateful_dataloader import StatefulDataLoader
@@ -38,7 +36,15 @@
 )
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed import broadcast_object_list
-from open_diloco.ckpt_utils import load_checkpoint, save_checkpoint
+from open_diloco.ckpt_utils import (
+    CKPT_PREFIX,
+    CkptConfig,
+    check_checkpoint_path_access,
+    delete_old_checkpoints,
+    get_diloco_rank_dir_name,
+    load_checkpoint,
+    save_checkpoint,
+)
 from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
 
 
@@ -58,7 +64,6 @@
 TIMEOUT_NCCL_MINUTES = os.environ.get("TIMEOUT_NCCL_MINUTES", 120)
 TARGET_LAYER_ACTIVATIONS = ["self_attn", "lm_head"]
 TEST_VOCAB_SIZE = 1024
-CKPT_PREFIX = "model_step"
 
 
 # Function to initialize the distributed process group
@@ -71,33 +76,6 @@ def log(message):
     logger.info(f"[rank {os.environ['LOCAL_RANK']}] {message}")
 
 
-def check_checkpoint_path_access(checkpoint_path: str, rank: int, world_rank_hv: int | None = None):
-    if world_rank_hv:
-        dummy_file_path = os.path.join(
-            checkpoint_path, get_diloco_rank_dir_name(world_rank_hv), f"dummy_file_{rank}.txt"
-        )
-    else:
-        dummy_file_path = os.path.join(checkpoint_path, f"dummy_file_{rank}.txt")
-
-    with fsspec.open(dummy_file_path, "w") as f:
-        f.write("This is a dummy file for testing access.")
-    gfs = GenericFileSystem()
-    gfs.rm(dummy_file_path)
-
-
-def get_diloco_rank_dir_name(world_rank_diloco: int) -> str:
-    return f"diloco_rank_{world_rank_diloco}"
-
-
-def delete_old_checkpoints(checkpoint_path: str, topk: int):
-    fs = GenericFileSystem()
-    ckpt_files = [f for f in fs.ls(checkpoint_path, detail=False) if filter_ckpt_files(f)]
-    ckpt_files.sort(key=lambda x: int(x.split("_")[-1]))
-    for ckpt_file in ckpt_files[:-topk]:
-        log(f"Deleting old checkpoint {ckpt_file}")
-        fs.rm(ckpt_file, recursive=True)
-
-
 class HvConfig(BaseConfig):
     outer_lr: float = 0.7
     local_steps: int = 500
@@ -123,40 +101,6 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
     return values
 
 
-def filter_ckpt_files(f):
-    if CKPT_PREFIX not in f:
-        return False
-    else:
-        try:
-            int(f.split("_")[-1])
-            return True
-        except ValueError:
-            return False
-
-
-class CkptConfig(BaseConfig):
-    resume: str | bool | None = None  # if resume is a boolean, it means we should resume from the last checkpoint
-    interval: int | None = None
-    path: str = "outputs"
-    topk: int | None = None  # how many checkpoints to keep
-
-    def get_resume_path(self):
-        if self.resume is None:
-            raise ValueError("Resume path is not set")
-        elif isinstance(self.resume, bool):
-            # Using fsspec to list directory contents
-            fs = GenericFileSystem()
-            ckpt_files = [f for f in fs.ls(self.path, detail=False) if filter_ckpt_files(f)]
-
-            if len(ckpt_files) == 0:
-                raise ValueError(f"No checkpoints found in {self.path}")
-
-            latest_ckpt = max(ckpt_files, key=lambda f: int(f.split("_")[-1]))
-            return latest_ckpt
-
-        return self.resume
-
-
 class Config(BaseConfig):
     path_model: str = "PrimeIntellect/llama-150m-fresh"
     torch_compile: bool = True
@@ -559,7 +503,9 @@ def scheduler_fn(opt):
         if local_rank == 0:
             # only the rank 0 deletes the checkpoints
             if config.ckpt.topk is not None:
-                delete_old_checkpoints(config.ckpt.path, config.ckpt.topk)
+                ckpt_deleted = delete_old_checkpoints(config.ckpt.path, config.ckpt.topk)
+                if ckpt_deleted:
+                    log(f"Deleted old checkpoints: {ckpt_deleted}")
 
         loss_batch = 0
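To make the shared filter concrete, a quick sketch with hypothetical directory entries showing which names `filter_ckpt_files` accepts — only entries carrying `CKPT_PREFIX` with an integer step suffix count as checkpoints for both resume and topk deletion:

    # Hypothetical names, illustrating the filter shared by resume and topk deletion.
    from open_diloco.ckpt_utils import filter_ckpt_files

    assert filter_ckpt_files("outputs/model_step_1000")        # prefix + integer step
    assert not filter_ckpt_files("outputs/model_step_latest")  # suffix is not an int
    assert not filter_ckpt_files("outputs/logs")               # no CKPT_PREFIX at all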
