
Commit 2774065

lucas-ventura and mreso authored
Improve model checkpoint saving logic (meta-llama#691)
Co-authored-by: Matthias Reso <[email protected]>
1 parent 2501f51 commit 2774065

File tree

1 file changed, +49 -45 lines changed


src/llama_recipes/utils/train_utils.py

Lines changed: 49 additions & 45 deletions
@@ -220,70 +220,74 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
         # Update the learning rate as needed
         lr_scheduler.step()
+        should_save_model = train_config.save_model
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
             if train_config.save_metrics:
                 val_step_loss.extend(temp_val_loss)
                 val_step_perplexity.extend(temp_step_perplexity)
-
-            checkpoint_start_time = time.perf_counter()
-            if train_config.save_model and eval_epoch_loss < best_val_loss:
+            should_save_model = train_config.save_model and eval_epoch_loss < best_val_loss
+
+        checkpoint_start_time = time.perf_counter()
+        if should_save_model:
+            if train_config.enable_fsdp:
+                dist.barrier()
+            if train_config.use_peft:
                 if train_config.enable_fsdp:
-                    dist.barrier()
-                if train_config.use_peft:
-                    if train_config.enable_fsdp:
-                        if rank==0:
-                            print(f"we are about to save the PEFT modules")
-                    else:
+                    if rank==0:
                         print(f"we are about to save the PEFT modules")
-                    save_peft_checkpoint(model, train_config.output_dir)
-                    if train_config.enable_fsdp:
-                        if rank==0:
-                            print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                    else:
+                else:
+                    print(f"we are about to save the PEFT modules")
+                save_peft_checkpoint(model, train_config.output_dir)
+                if train_config.enable_fsdp:
+                    if rank==0:
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
-
                 else:
-                    if not train_config.enable_fsdp:
-                        save_model_checkpoint(model, train_config.output_dir)
-
-                    elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-                        print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
+                    print(f"PEFT modules are saved in {train_config.output_dir} directory")
+
+            else:
+                if not train_config.enable_fsdp:
+                    save_model_checkpoint(model, train_config.output_dir)
+
+                elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
+                    print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
+                    print("=====================================================")
+                    save_fsdp_model_checkpoint_full(
+                        model, optimizer, rank, train_config, epoch=epoch
+                    )
+
+                    if train_config.save_optimizer:
+                        print(" Saving the FSDP optimizer using FULL_STATE_DICT")
                         print("=====================================================")
-                        save_fsdp_model_checkpoint_full(
+                        save_optimizer_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
-
-                        if train_config.save_optimizer:
-                            print(" Saving the FSDP optimizer using FULL_STATE_DICT")
-                            print("=====================================================")
-                            save_optimizer_checkpoint(
-                                model, optimizer, rank, train_config, epoch=epoch
-                            )
-
-                    elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
-
-                        if train_config.save_optimizer:
-                            print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
-                            print("=====================================================")
-                            save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
-                        else:
-                            print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
-                            print("=====================================================")
-                            save_model_and_optimizer_sharded(model, rank, train_config)
+
+                elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
 
-
-                if train_config.enable_fsdp:
-                    dist.barrier()
-            checkpoint_end_time = time.perf_counter() - checkpoint_start_time
-            checkpoint_times.append(checkpoint_end_time)
+                    if train_config.save_optimizer:
+                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
+                        print("=====================================================")
+                        save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+                    else:
+                        print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
+                        print("=====================================================")
+                        save_model_and_optimizer_sharded(model, rank, train_config)
+
+
+            if train_config.enable_fsdp:
+                dist.barrier()
+        checkpoint_end_time = time.perf_counter() - checkpoint_start_time
+        checkpoint_times.append(checkpoint_end_time)
+
+        if train_config.run_validation:
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
                 if train_config.enable_fsdp:
                     if rank==0:
                         print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                 else:
-                        print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
+                    print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
             val_loss.append(float(best_val_loss))
             val_prep.append(float(eval_ppl))
         if train_config.enable_fsdp:
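Net effect of the change: checkpoint saving is now gated by a single should_save_model flag computed before the save block, so a checkpoint can still be written when run_validation is disabled, and, when validation does run, only when the eval loss improves. Below is a minimal sketch of that gating; TrainConfig, maybe_save_checkpoint, and save_fn are hypothetical stand-ins for illustration, not names from the repo.

import time
from dataclasses import dataclass

@dataclass
class TrainConfig:
    # Hypothetical minimal stand-in for the real train_config used in train_utils.py
    save_model: bool = True
    run_validation: bool = True

def maybe_save_checkpoint(train_config, eval_epoch_loss, best_val_loss, checkpoint_times, save_fn):
    # Default: save whenever save_model is set, even if validation is disabled.
    should_save_model = train_config.save_model
    if train_config.run_validation:
        # With validation enabled, additionally require an improved eval loss.
        should_save_model = train_config.save_model and eval_epoch_loss < best_val_loss

    checkpoint_start_time = time.perf_counter()
    if should_save_model:
        save_fn()  # stand-in for the PEFT / FSDP checkpoint helpers called in the diff
    checkpoint_times.append(time.perf_counter() - checkpoint_start_time)
    return should_save_model

# No validation: the checkpoint is still written.
times = []
print(maybe_save_checkpoint(TrainConfig(run_validation=False), None, float("inf"), times, lambda: None))  # True
# Validation enabled but eval loss did not improve: skip the save.
print(maybe_save_checkpoint(TrainConfig(), 2.0, 1.5, times, lambda: None))  # False

The real train() in train_utils.py dispatches to the PEFT/FSDP-specific helpers (save_peft_checkpoint, save_model_checkpoint, save_fsdp_model_checkpoint_full, save_optimizer_checkpoint, save_model_and_optimizer_sharded) where the sketch calls save_fn.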
