@@ -89,6 +89,15 @@ def get_diloco_rank_dir_name(world_rank_diloco: int) -> str:
89
89
return f"diloco_rank_{ world_rank_diloco } "
90
90
91
91
92
def delete_old_checkpoints(checkpoint_path: str, topk: int):
    """Delete all but the ``topk`` most recent checkpoints in ``checkpoint_path``.

    Directory entries are listed through ``GenericFileSystem`` and only those
    accepted by ``filter_ckpt_files`` are considered. Checkpoints are ordered
    by the integer that follows the last underscore in their path; the
    ``topk`` entries with the largest value are kept and the rest are removed
    recursively.

    Args:
        checkpoint_path: Directory that contains the checkpoints.
        topk: Number of most recent checkpoints to keep. Must be >= 1.

    Raises:
        ValueError: If ``topk`` is smaller than 1.
    """
    # Validate before touching the filesystem. With the plain slice
    # ``ckpt_files[:-topk]`` a value of 0 yields ``[:0]`` (nothing deleted) and
    # a negative value yields ``[:abs(topk)]`` (deletes the wrong entries).
    if topk < 1:
        raise ValueError(f"topk must be a positive integer, got {topk}")

    fs = GenericFileSystem()
    ckpt_files = sorted(
        (f for f in fs.ls(checkpoint_path, detail=False) if filter_ckpt_files(f)),
        key=lambda x: int(x.split("_")[-1]),
    )

    # Explicit count instead of a negative slice bound, so the "fewer than
    # topk checkpoints exist" case is an obvious no-op.
    num_to_delete = max(len(ckpt_files) - topk, 0)
    for ckpt_file in ckpt_files[:num_to_delete]:
        log(f"Deleting old checkpoint {ckpt_file}")
        fs.rm(ckpt_file, recursive=True)
99
+
100
+
92
101
class HvConfig (BaseConfig ):
93
102
outer_lr : float = 0.7
94
103
local_steps : int = 500
@@ -114,31 +123,34 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
114
123
return values
115
124
116
125
126
def filter_ckpt_files(f):
    """Return True if ``f`` looks like a checkpoint path.

    A path qualifies when it contains ``CKPT_PREFIX`` and the text after its
    last underscore parses as an integer; otherwise False is returned.
    """
    if CKPT_PREFIX not in f:
        return False

    step_suffix = f.split("_")[-1]
    try:
        int(step_suffix)
    except ValueError:
        # Suffix after the last underscore is not a number.
        return False
    return True
135
+
136
+
117
137
class CkptConfig (BaseConfig ):
118
138
resume : str | bool | None = None # if resume is a boolean, it means we should resume from the last checkpoint
119
139
interval : int | None = None
120
140
path : str = "outputs"
141
+ topk : int | None = None # how many checkpoints to keep
121
142
122
143
def get_resume_path (self ):
123
144
if self .resume is None :
124
145
raise ValueError ("Resume path is not set" )
125
146
elif isinstance (self .resume , bool ):
126
147
# Using fsspec to list directory contents
127
148
fs = GenericFileSystem ()
149
+ ckpt_files = [f for f in fs .ls (self .path , detail = False ) if filter_ckpt_files (f )]
128
150
129
- def filter_ckpt_files (f ):
130
- if CKPT_PREFIX not in f :
131
- return False
132
- else :
133
- try :
134
- int (f .split ("_" )[- 1 ])
135
- return True
136
- except ValueError :
137
- return False
151
+ if len (ckpt_files ) == 0 :
152
+ raise ValueError (f"No checkpoints found in { self .path } " )
138
153
139
- ckpt_files = [f for f in fs .ls (self .path , detail = False ) if filter_ckpt_files (f )]
140
- # Regex to extract numbers following the CKPT_PREFIX and an underscore
141
- # f is usually something like this "file:///hello/model_step_100000"
142
154
latest_ckpt = max (ckpt_files , key = lambda f : int (f .split ("_" )[- 1 ]))
143
155
return latest_ckpt
144
156
@@ -544,6 +556,11 @@ def scheduler_fn(opt):
544
556
save_global_state = rank == 0 ,
545
557
)
546
558
559
+ if local_rank == 0 :
560
+ # only the process with local_rank 0 deletes old checkpoints
561
+ if config .ckpt .topk is not None :
562
+ delete_old_checkpoints (config .ckpt .path , config .ckpt .topk )
563
+
547
564
loss_batch = 0
548
565
549
566
if config .max_steps is not None and real_step >= config .max_steps :
0 commit comments