 from datasets.distributed import split_dataset_by_node
 from fsspec.generic import GenericFileSystem
 from torch.distributed import destroy_process_group, init_process_group
+
 from torchdata.stateful_dataloader import StatefulDataLoader
 from transformers import (
     AutoTokenizer,
@@ -69,26 +70,18 @@ def log(message):
     logger.info(f"[rank {os.environ['LOCAL_RANK']}] {message}")
 
 
-def get_ckpt_folder(checkpoint_path, training_date, project, run_id):
-    return os.path.join(checkpoint_path, training_date, project, run_id)
-
-
-def check_checkpoint_path_access(checkpoint_path: str, training_date, project, run_id, rank):
-    dummy_file_path = os.path.join(
-        get_ckpt_folder(
-            checkpoint_path=checkpoint_path,
-            training_date=training_date,
-            project=project,
-            run_id=run_id,
-        ),
-        f"dummy_file_{rank}.txt",
-    )
+def check_checkpoint_path_access(checkpoint_path: str, rank: int):
+    dummy_file_path = os.path.join(checkpoint_path, f"dummy_file_{rank}.txt")
     with fsspec.open(dummy_file_path, "w") as f:
         f.write("This is a dummy file for testing access.")
     gfs = GenericFileSystem()
     gfs.rm(dummy_file_path)
 
 
+def get_diloco_rank_dir_name(world_rank_diloco: int) -> str:
+    return f"diloco_rank_{world_rank_diloco}"
+
+
 class HvConfig(BaseConfig):
     outer_lr: float = 0.7
     local_steps: int = 500
@@ -202,10 +195,6 @@ def train(config: Config):
     assert batch_size % config.per_device_train_batch_size == 0
     gradient_accumulation_steps = batch_size // config.per_device_train_batch_size
 
-    training_date = datetime.datetime.now().strftime(
-        "%Y-%m-%d"
-    )  # we define the data at the beginning of training in case the training take several days
-
     if config.hv is not None:
         sharding_strategy = ShardingStrategy.NO_SHARD
         log("Hivemind is used, ShardingStrategy.NO_SHARD is used")
@@ -232,7 +221,7 @@ def train(config: Config):
         log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=False)
 
     if local_rank == 0:
-        check_checkpoint_path_access(config.checkpoint_path, training_date, config.project, run_id, rank)
+        check_checkpoint_path_access(config.checkpoint_path, rank)
 
     # DataLoader preparation
     tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
@@ -290,7 +279,9 @@ def scheduler_fn(opt):
         # Otherwise the world messenger will get lonely and hang
         fake_optimizer = inner_optimizer(model.parameters())
         last_loss = load_checkpoint(
-            checkpoint_path=config.resume_from_checkpoint,
+            checkpoint_path=os.path.join(
+                config.resume_from_checkpoint, get_diloco_rank_dir_name(config.hv.world_rank)
+            ),
             model=model,
             optimizer=fake_optimizer,
         )
@@ -329,7 +320,9 @@ def scheduler_fn(opt):
 
         if config.resume_from_checkpoint:
             last_loss = load_checkpoint(
-                checkpoint_path=config.resume_from_checkpoint,
+                checkpoint_path=os.path.join(
+                    config.resume_from_checkpoint, get_diloco_rank_dir_name(config.hv.world_rank)
+                ),
                 model=model,
                 optimizer=optimizer.inner_optimizer,
                 scheduler=scheduler,
@@ -470,16 +463,13 @@ def scheduler_fn(opt):
             # Save checkpoint every 'checkpoint_interval' steps
             if config.checkpoint_interval is not None and real_step % config.checkpoint_interval == 0:
                 log(f"saving at step {real_step}, step {step + 1}")
-                ckpt_path = os.path.join(
-                    get_ckpt_folder(config.checkpoint_path, training_date, config.project, run_id),
-                    f"model_step_{int(real_step)}",
-                )
+                ckpt_path = os.path.join(config.checkpoint_path, f"model_step_{int(real_step)}")
 
                 if world_messenger_hv:
                     assert isinstance(optimizer, DiLoCoOptimizer)
                     with optimizer.tracker.pause_updates():
                         save_checkpoint(
-                            checkpoint_path=ckpt_path,
+                            checkpoint_path=os.path.join(ckpt_path, get_diloco_rank_dir_name(config.hv.world_rank)),
                             model=model,
                             optimizer=optimizer.inner_optimizer,
                             scheduler=scheduler,
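For reference, a minimal sketch of the on-disk layout these changes produce when saving. The concrete values for checkpoint_path, real_step and the DiLoCo world rank are illustrative assumptions, not values taken from the diff:

import os

def get_diloco_rank_dir_name(world_rank_diloco: int) -> str:
    # Same helper as introduced above: one sub-directory per DiLoCo rank.
    return f"diloco_rank_{world_rank_diloco}"

# Illustrative values only.
checkpoint_path = "/checkpoints"  # assumed config.checkpoint_path
real_step = 500                   # assumed current step
world_rank = 3                    # assumed config.hv.world_rank

ckpt_path = os.path.join(checkpoint_path, f"model_step_{int(real_step)}")
save_dir = os.path.join(ckpt_path, get_diloco_rank_dir_name(world_rank))
print(save_dir)  # /checkpoints/model_step_500/diloco_rank_3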