
Commit ba63895

refactor checkpointing: remove date and run id
1 parent: 269cf90

File tree: 1 file changed

open_diloco/train_fsdp.py

Lines changed: 16 additions & 26 deletions
```diff
@@ -22,6 +22,7 @@
 from datasets.distributed import split_dataset_by_node
 from fsspec.generic import GenericFileSystem
 from torch.distributed import destroy_process_group, init_process_group
+
 from torchdata.stateful_dataloader import StatefulDataLoader
 from transformers import (
     AutoTokenizer,
```

```diff
@@ -69,26 +70,18 @@ def log(message):
     logger.info(f"[rank {os.environ['LOCAL_RANK']}] {message}")


-def get_ckpt_folder(checkpoint_path, training_date, project, run_id):
-    return os.path.join(checkpoint_path, training_date, project, run_id)
-
-
-def check_checkpoint_path_access(checkpoint_path: str, training_date, project, run_id, rank):
-    dummy_file_path = os.path.join(
-        get_ckpt_folder(
-            checkpoint_path=checkpoint_path,
-            training_date=training_date,
-            project=project,
-            run_id=run_id,
-        ),
-        f"dummy_file_{rank}.txt",
-    )
+def check_checkpoint_path_access(checkpoint_path: str, rank: int):
+    dummy_file_path = os.path.join(checkpoint_path, f"dummy_file_{rank}.txt")
     with fsspec.open(dummy_file_path, "w") as f:
         f.write("This is a dummy file for testing access.")
     gfs = GenericFileSystem()
     gfs.rm(dummy_file_path)


+def get_diloco_rank_dir_name(world_rank_diloco: int) -> str:
+    return f"diloco_rank_{world_rank_diloco}"
+
+
 class HvConfig(BaseConfig):
     outer_lr: float = 0.7
     local_steps: int = 500
```

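For reference, the two helpers as they stand after this commit, pulled out into a minimal standalone sketch (the `"ckpt"` path and rank value in the usage lines are made up for illustration):

```python
import os

import fsspec
from fsspec.generic import GenericFileSystem


def check_checkpoint_path_access(checkpoint_path: str, rank: int):
    # Write and then delete a throwaway file to verify the checkpoint path is writable.
    dummy_file_path = os.path.join(checkpoint_path, f"dummy_file_{rank}.txt")
    with fsspec.open(dummy_file_path, "w") as f:
        f.write("This is a dummy file for testing access.")
    GenericFileSystem().rm(dummy_file_path)


def get_diloco_rank_dir_name(world_rank_diloco: int) -> str:
    # Per-worker sub-directory inside each model_step_<N> checkpoint folder.
    return f"diloco_rank_{world_rank_diloco}"


check_checkpoint_path_access("ckpt", rank=0)  # hypothetical local path
print(get_diloco_rank_dir_name(2))            # -> diloco_rank_2
```
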
```diff
@@ -202,10 +195,6 @@ def train(config: Config):
     assert batch_size % config.per_device_train_batch_size == 0
     gradient_accumulation_steps = batch_size // config.per_device_train_batch_size

-    training_date = datetime.datetime.now().strftime(
-        "%Y-%m-%d"
-    )  # we define the date at the beginning of training in case the training takes several days
-
     if config.hv is not None:
         sharding_strategy = ShardingStrategy.NO_SHARD
         log("Hivemind is used, ShardingStrategy.NO_SHARD is used")
```

```diff
@@ -232,7 +221,7 @@ def train(config: Config):
     log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=False)

     if local_rank == 0:
-        check_checkpoint_path_access(config.checkpoint_path, training_date, config.project, run_id, rank)
+        check_checkpoint_path_access(config.checkpoint_path, rank)

     # DataLoader preparation
     tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
```

```diff
@@ -290,7 +279,9 @@ def scheduler_fn(opt):
         # Otherwise the world messenger will get lonely and hang
         fake_optimizer = inner_optimizer(model.parameters())
         last_loss = load_checkpoint(
-            checkpoint_path=config.resume_from_checkpoint,
+            checkpoint_path=os.path.join(
+                config.resume_from_checkpoint, get_diloco_rank_dir_name(config.hv.world_rank)
+            ),
             model=model,
             optimizer=fake_optimizer,
         )
```

```diff
@@ -329,7 +320,9 @@ def scheduler_fn(opt):

     if config.resume_from_checkpoint:
         last_loss = load_checkpoint(
-            checkpoint_path=config.resume_from_checkpoint,
+            checkpoint_path=os.path.join(
+                config.resume_from_checkpoint, get_diloco_rank_dir_name(config.hv.world_rank)
+            ),
             model=model,
             optimizer=optimizer.inner_optimizer,
             scheduler=scheduler,
```

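Loading is now symmetric with saving: `resume_from_checkpoint` should point at a `model_step_<N>` directory, and each DiLoCo worker appends its own rank sub-directory. A small sketch with made-up values (the f-string mirrors `get_diloco_rank_dir_name`):

```python
import os

resume_root = "ckpt/model_step_1000"  # config.resume_from_checkpoint (hypothetical)
world_rank = 3                        # config.hv.world_rank (hypothetical)
load_path = os.path.join(resume_root, f"diloco_rank_{world_rank}")
# -> "ckpt/model_step_1000/diloco_rank_3"
```
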
```diff
@@ -470,16 +463,13 @@ def scheduler_fn(opt):
         # Save checkpoint every 'checkpoint_interval' steps
         if config.checkpoint_interval is not None and real_step % config.checkpoint_interval == 0:
             log(f"saving at step {real_step}, step {step+1}")
-            ckpt_path = os.path.join(
-                get_ckpt_folder(config.checkpoint_path, training_date, config.project, run_id),
-                f"model_step_{int(real_step)}",
-            )
+            ckpt_path = os.path.join(config.checkpoint_path, f"model_step_{int(real_step)}")

             if world_messenger_hv:
                 assert isinstance(optimizer, DiLoCoOptimizer)
                 with optimizer.tracker.pause_updates():
                     save_checkpoint(
-                        checkpoint_path=ckpt_path,
+                        checkpoint_path=os.path.join(ckpt_path, get_diloco_rank_dir_name(config.hv.world_rank)),
                         model=model,
                         optimizer=optimizer.inner_optimizer,
                         scheduler=scheduler,
```

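Net effect on the on-disk layout, roughly (a sketch with made-up step numbers and two workers, not actual output):

```
ckpt/                      # config.checkpoint_path
├── model_step_500/
│   ├── diloco_rank_0/
│   └── diloco_rank_1/
└── model_step_1000/
    ├── diloco_rank_0/
    └── diloco_rank_1/
```
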
