Skip to content

Commit a53b294

Browse files
committed
diloco ckpt check fix
1 parent ba63895 commit a53b294

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

open_diloco/train_fsdp.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,13 @@ def log(message):
7070
logger.info(f"[rank {os.environ['LOCAL_RANK']}] {message}")
7171

7272

73-
def check_checkpoint_path_access(checkpoint_path: str, rank: int):
74-
dummy_file_path = os.path.join(checkpoint_path, f"dummy_file_{rank}.txt")
73+
def check_checkpoint_path_access(checkpoint_path: str, rank: int, world_rank_hv: int | None = None):
74+
if world_rank_hv:
75+
dummy_file_path = os.path.join(
76+
checkpoint_path, get_diloco_rank_dir_name(world_rank_hv), f"dummy_file_{rank}.txt"
77+
)
78+
else:
79+
dummy_file_path = os.path.join(checkpoint_path, f"dummy_file_{rank}.txt")
7580
with fsspec.open(dummy_file_path, "w") as f:
7681
f.write("This is a dummy file for testing access.")
7782
gfs = GenericFileSystem()
@@ -221,7 +226,7 @@ def train(config: Config):
221226
log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=False)
222227

223228
if local_rank == 0:
224-
check_checkpoint_path_access(config.checkpoint_path, rank)
229+
check_checkpoint_path_access(config.checkpoint_path, rank, config.hv.world_rank if config.hv else None)
225230

226231
# DataLoader preparation
227232
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)

0 commit comments

Comments
 (0)