@@ -13,14 +13,12 @@
 import datetime
 from typing import Any, Literal
 
-import fsspec
 from pydantic import model_validator
 import torch
 import wandb
 from pydantic_config import parse_argv, BaseConfig
 from datasets import load_dataset
 from datasets.distributed import split_dataset_by_node
-from fsspec.generic import GenericFileSystem
 from torch.distributed import destroy_process_group, init_process_group
 
 from torchdata.stateful_dataloader import StatefulDataLoader
@@ -38,7 +36,15 @@
 )
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed import broadcast_object_list
-from open_diloco.ckpt_utils import load_checkpoint, save_checkpoint
+from open_diloco.ckpt_utils import (
+    CKPT_PREFIX,
+    CkptConfig,
+    check_checkpoint_path_access,
+    delete_old_checkpoints,
+    get_diloco_rank_dir_name,
+    load_checkpoint,
+    save_checkpoint,
+)
 from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
 
 
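The grouped import above is the heart of this refactor: every checkpoint helper now lives in `open_diloco.ckpt_utils`. As a hedged sketch (not part of this diff), the relocated `CkptConfig` is presumably nested under a `ckpt` field of the training `Config`, which is what the `config.ckpt.path` / `config.ckpt.topk` accesses in the final hunk suggest:

```python
# Sketch only: the `ckpt` field name is inferred from the `config.ckpt.*`
# accesses later in this diff; the wiring is not shown in these hunks.
from pydantic_config import BaseConfig
from open_diloco.ckpt_utils import CkptConfig


class Config(BaseConfig):
    path_model: str = "PrimeIntellect/llama-150m-fresh"
    ckpt: CkptConfig = CkptConfig()  # checkpointing options, relocated to ckpt_utils
```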
@@ -58,7 +64,6 @@
 TIMEOUT_NCCL_MINUTES = os.environ.get("TIMEOUT_NCCL_MINUTES", 120)
 TARGET_LAYER_ACTIVATIONS = ["self_attn", "lm_head"]
 TEST_VOCAB_SIZE = 1024
-CKPT_PREFIX = "model_step"
 
 
 # Function to initialize the distributed process group
@@ -71,33 +76,6 @@ def log(message):
     logger.info(f"[rank {os.environ['LOCAL_RANK']}] {message}")
 
 
-def check_checkpoint_path_access(checkpoint_path: str, rank: int, world_rank_hv: int | None = None):
-    if world_rank_hv:
-        dummy_file_path = os.path.join(
-            checkpoint_path, get_diloco_rank_dir_name(world_rank_hv), f"dummy_file_{rank}.txt"
-        )
-    else:
-        dummy_file_path = os.path.join(checkpoint_path, f"dummy_file_{rank}.txt")
-
-    with fsspec.open(dummy_file_path, "w") as f:
-        f.write("This is a dummy file for testing access.")
-    gfs = GenericFileSystem()
-    gfs.rm(dummy_file_path)
-
-
-def get_diloco_rank_dir_name(world_rank_diloco: int) -> str:
-    return f"diloco_rank_{world_rank_diloco}"
-
-
-def delete_old_checkpoints(checkpoint_path: str, topk: int):
-    fs = GenericFileSystem()
-    ckpt_files = [f for f in fs.ls(checkpoint_path, detail=False) if filter_ckpt_files(f)]
-    ckpt_files.sort(key=lambda x: int(x.split("_")[-1]))
-    for ckpt_file in ckpt_files[:-topk]:
-        log(f"Deleting old checkpoint {ckpt_file}")
-        fs.rm(ckpt_file, recursive=True)
-
-
 class HvConfig(BaseConfig):
     outer_lr: float = 0.7
     local_steps: int = 500
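The three helpers deleted above move into `open_diloco.ckpt_utils` (they are re-imported in the second hunk). A brief usage sketch of the access probe, with illustrative arguments; `check_checkpoint_path_access` writes a throwaway file via fsspec and removes it, raising if the path is not writable:

```python
from open_diloco.ckpt_utils import check_checkpoint_path_access, get_diloco_rank_dir_name

# Plain run: probes outputs/dummy_file_0.txt, then deletes it.
check_checkpoint_path_access("outputs", rank=0)

# With a DiLoCo world rank, the probe lands in the per-rank subdirectory,
# e.g. outputs/diloco_rank_3/dummy_file_0.txt
check_checkpoint_path_access("outputs", rank=0, world_rank_hv=3)
assert get_diloco_rank_dir_name(3) == "diloco_rank_3"
```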
@@ -123,40 +101,6 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
         return values
 
 
-def filter_ckpt_files(f):
-    if CKPT_PREFIX not in f:
-        return False
-    else:
-        try:
-            int(f.split("_")[-1])
-            return True
-        except ValueError:
-            return False
-
-
-class CkptConfig(BaseConfig):
-    resume: str | bool | None = None  # if resume is a boolean, it means we should resume from the last checkpoint
-    interval: int | None = None
-    path: str = "outputs"
-    topk: int | None = None  # how many checkpoints to keep
-
-    def get_resume_path(self):
-        if self.resume is None:
-            raise ValueError("Resume path is not set")
-        elif isinstance(self.resume, bool):
-            # Using fsspec to list directory contents
-            fs = GenericFileSystem()
-            ckpt_files = [f for f in fs.ls(self.path, detail=False) if filter_ckpt_files(f)]
-
-            if len(ckpt_files) == 0:
-                raise ValueError(f"No checkpoints found in {self.path}")
-
-            latest_ckpt = max(ckpt_files, key=lambda f: int(f.split("_")[-1]))
-            return latest_ckpt
-
-        return self.resume
-
-
 class Config(BaseConfig):
     path_model: str = "PrimeIntellect/llama-150m-fresh"
     torch_compile: bool = True
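`CkptConfig.get_resume_path` (deleted above, relocated to `ckpt_utils`) resolves `resume=True` by listing `path` for entries matching `CKPT_PREFIX` (`"model_step"`) with an integer step suffix and returning the highest one. A short usage sketch; the step number is illustrative:

```python
from open_diloco.ckpt_utils import CkptConfig

cfg = CkptConfig(resume=True, path="outputs", interval=500, topk=3)
# Lists "outputs" for model_step_* entries and returns the latest,
# e.g. "outputs/model_step_3000"; raises ValueError if none are found.
latest = cfg.get_resume_path()
```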
@@ -559,7 +503,9 @@ def scheduler_fn(opt):
             if local_rank == 0:
                 # only the rank 0 deletes the checkpoints
                 if config.ckpt.topk is not None:
-                    delete_old_checkpoints(config.ckpt.path, config.ckpt.topk)
+                    ckpt_deleted = delete_old_checkpoints(config.ckpt.path, config.ckpt.topk)
+                    if ckpt_deleted:
+                        log(f"Deleted old checkpoints: {ckpt_deleted}")
 
             loss_batch = 0
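The new call site assigns and logs the return value, which implies the relocated `delete_old_checkpoints` now returns the paths it removed. A minimal sketch of that signature under this assumption, adapted from the body deleted earlier in this diff; the real implementation lives in `ckpt_utils` and is not shown here:

```python
from fsspec.generic import GenericFileSystem

from open_diloco.ckpt_utils import filter_ckpt_files  # assumed to move alongside


def delete_old_checkpoints(checkpoint_path: str, topk: int) -> list[str]:
    fs = GenericFileSystem()
    ckpt_files = [f for f in fs.ls(checkpoint_path, detail=False) if filter_ckpt_files(f)]
    ckpt_files.sort(key=lambda x: int(x.split("_")[-1]))  # ascending by step number
    deleted = ckpt_files[:-topk]  # everything except the topk most recent
    for ckpt_file in deleted:
        fs.rm(ckpt_file, recursive=True)
    return deleted
```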