@@ -220,70 +220,74 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
         # Update the learning rate as needed
         lr_scheduler.step()
+        should_save_model = train_config.save_model
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
             if train_config.save_metrics:
                 val_step_loss.extend(temp_val_loss)
                 val_step_perplexity.extend(temp_step_perplexity)
-
-            checkpoint_start_time = time.perf_counter()
-            if train_config.save_model and eval_epoch_loss < best_val_loss:
+            should_save_model = train_config.save_model and eval_epoch_loss < best_val_loss
+
+        checkpoint_start_time = time.perf_counter()
+        if should_save_model:
+            if train_config.enable_fsdp:
+                dist.barrier()
+            if train_config.use_peft:
                 if train_config.enable_fsdp:
-                    dist.barrier()
-                if train_config.use_peft:
-                    if train_config.enable_fsdp:
-                        if rank==0:
-                            print(f"we are about to save the PEFT modules")
-                    else:
+                    if rank==0:
                         print(f"we are about to save the PEFT modules")
-                    save_peft_checkpoint(model, train_config.output_dir)
-                    if train_config.enable_fsdp:
-                        if rank==0:
-                            print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                    else:
+                else:
+                    print(f"we are about to save the PEFT modules")
+                save_peft_checkpoint(model, train_config.output_dir)
+                if train_config.enable_fsdp:
+                    if rank==0:
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
-
                 else:
-                    if not train_config.enable_fsdp:
-                        save_model_checkpoint(model, train_config.output_dir)
-
-                    elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-                        print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
+                    print(f"PEFT modules are saved in {train_config.output_dir} directory")
+
+            else:
+                if not train_config.enable_fsdp:
+                    save_model_checkpoint(model, train_config.output_dir)
+
+                elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
+                    print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
+                    print("=====================================================")
+                    save_fsdp_model_checkpoint_full(
+                        model, optimizer, rank, train_config, epoch=epoch
+                    )
+
+                    if train_config.save_optimizer:
+                        print(" Saving the FSDP optimizer using FULL_STATE_DICT")
                         print("=====================================================")
-                        save_fsdp_model_checkpoint_full(
+                        save_optimizer_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
-
-                        if train_config.save_optimizer:
-                            print(" Saving the FSDP optimizer using FULL_STATE_DICT")
-                            print("=====================================================")
-                            save_optimizer_checkpoint(
-                                model, optimizer, rank, train_config, epoch=epoch
-                            )
-
-                    elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
-
-                        if train_config.save_optimizer:
-                            print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
-                            print("=====================================================")
-                            save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
-                        else:
-                            print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
-                            print("=====================================================")
-                            save_model_and_optimizer_sharded(model, rank, train_config)
+
+                elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
 
-
-            if train_config.enable_fsdp:
-                dist.barrier()
-            checkpoint_end_time = time.perf_counter() - checkpoint_start_time
-            checkpoint_times.append(checkpoint_end_time)
+                    if train_config.save_optimizer:
+                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
+                        print("=====================================================")
+                        save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+                    else:
+                        print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
+                        print("=====================================================")
+                        save_model_and_optimizer_sharded(model, rank, train_config)
+
+
+            if train_config.enable_fsdp:
+                dist.barrier()
+            checkpoint_end_time = time.perf_counter() - checkpoint_start_time
+            checkpoint_times.append(checkpoint_end_time)
+
+        if train_config.run_validation:
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
                 if train_config.enable_fsdp:
                     if rank==0:
                         print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                 else:
-                    print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
+                    print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
             val_loss.append(float(best_val_loss))
             val_prep.append(float(eval_ppl))
     if train_config.enable_fsdp:
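Note on the change: before this patch the whole checkpointing block lived inside `if train_config.run_validation:`, so running with `run_validation=False` silently disabled `save_model` as well. The patch hoists the decision into a `should_save_model` flag that defaults to `train_config.save_model` and is only tightened to "save on improved eval loss" when validation actually runs. A minimal, self-contained sketch of that control flow (placeholder `TrainConfig` and callables, not the repo's actual helpers):

```python
from dataclasses import dataclass

@dataclass
class TrainConfig:  # placeholder carrying only the two flags this sketch needs
    save_model: bool = True
    run_validation: bool = True

def epoch_end(cfg, evaluate, save_checkpoint, best_val_loss):
    # Mirror of the patched flow: saving no longer requires validation.
    should_save_model = cfg.save_model
    eval_epoch_loss = None
    if cfg.run_validation:
        eval_epoch_loss = evaluate()
        # With validation on, keep the old rule: save only on improvement.
        should_save_model = cfg.save_model and eval_epoch_loss < best_val_loss

    if should_save_model:
        save_checkpoint()

    if cfg.run_validation and eval_epoch_loss < best_val_loss:
        best_val_loss = eval_epoch_loss
    return best_val_loss

# With run_validation=False the checkpoint is still written -- the case the
# old code silently skipped.
best = epoch_end(TrainConfig(run_validation=False),
                 evaluate=lambda: 0.5,
                 save_checkpoint=lambda: print("checkpoint saved"),
                 best_val_loss=float("inf"))
```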
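For reference, the save path that now sits under `if should_save_model:` dispatches on PEFT vs. full model and, under FSDP, on `fsdp_config.checkpoint_type`. Below is a runnable condensation of that dispatch: the `savers` callables stand in for the llama-recipes helpers named in the diff (`save_peft_checkpoint`, `save_model_checkpoint`, `save_fsdp_model_checkpoint_full`, `save_optimizer_checkpoint`, `save_model_and_optimizer_sharded`), the string constants stand in for `torch.distributed.fsdp.StateDictType`, and the `dist.barrier()` calls that bracket the real save are omitted so the sketch runs without a process group.

```python
import time

# Stand-ins for the StateDictType enum members used in the diff.
FULL_STATE_DICT, SHARDED_STATE_DICT = "full", "sharded"

def run_save(use_peft, enable_fsdp, checkpoint_type, save_optimizer, savers):
    start = time.perf_counter()
    if use_peft:
        savers["peft"]()              # save_peft_checkpoint(model, output_dir)
    elif not enable_fsdp:
        savers["whole_model"]()       # save_model_checkpoint(model, output_dir)
    elif checkpoint_type == FULL_STATE_DICT:
        savers["fsdp_full"]()         # save_fsdp_model_checkpoint_full(...)
        if save_optimizer:
            savers["optimizer"]()     # save_optimizer_checkpoint(...)
    elif checkpoint_type == SHARDED_STATE_DICT:
        # The sharded path saves model and, if requested, optimizer in one call
        # (save_model_and_optimizer_sharded with or without optim=optimizer).
        savers["sharded"]()
    return time.perf_counter() - start  # what gets appended to checkpoint_times

noop = {k: (lambda k=k: print(f"{k} checkpoint written"))
        for k in ("peft", "whole_model", "fsdp_full", "optimizer", "sharded")}
print(f"checkpointing took {run_save(False, True, SHARDED_STATE_DICT, True, noop):.6f}s")
```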