
# Standard
from dataclasses import dataclass
import os

# Third Party
from transformers import (
    Trainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
import torch

# Local
from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass

is_recover_safetensors_from_dcp_available = True
try:
    # Third Party
    from fms_acceleration_moe.utils import recover_safetensors_from_dcp
except ImportError:
    is_recover_safetensors_from_dcp_available = False

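# NOTE: when fms_acceleration_moe is not installed, the flag above stays False
# and get_callbacks() below returns an empty list, so training proceeds without
# the DCP-to-safetensors conversion callback.
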
@parsable_dataclass
@dataclass
@@ -34,3 +54,77 @@ class FastMoeConfig:
    def __post_init__(self):
        # ensure nested dataclasses initialized
        ensure_nested_dataclasses_initialized(self)


def get_callbacks(**kwargs):
    pretrained_model_name_or_path = kwargs.pop("pretrained_model_name_or_path")
    trainer = kwargs.pop("trainer")
    callbacks = []
    if is_recover_safetensors_from_dcp_available:

        class ConvertAndSaveHFCheckpointAtEverySave(TrainerCallback):
            def __init__(self, pretrained_model_name_or_path: str, trainer: Trainer):
                self.pretrained_model_name_or_path = pretrained_model_name_or_path
                self.trainer = trainer

            def on_save(
                self,
                args: TrainingArguments,
                state: TrainerState,
                control: TrainerControl,
                **kwargs,
            ):
                """
                Save all HF files and convert the DCP checkpoint to safetensors
                at every save operation.
                """

                def checkpoint():
                    checkpoint_dir = os.path.join(
                        args.output_dir,
                        f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
                    )
                    hf_converted_output_dir = os.path.join(
                        checkpoint_dir, "hf_converted_checkpoint"
                    )
                    if os.path.exists(hf_converted_output_dir):
                        # The folder may already exist because the checkpoint is
                        # saved again at the end of training; return early.
                        return
                    os.mkdir(hf_converted_output_dir)
                    try:
                        recover_safetensors_from_dcp(
                            checkpoint_dir,
                            self.pretrained_model_name_or_path,
                            hf_converted_output_dir,
                        )
                        # save tokenizer
                        if self.trainer.processing_class:
                            self.trainer.processing_class.save_pretrained(
                                hf_converted_output_dir
                            )
                        # save training args
                        torch.save(
                            args,
                            os.path.join(hf_converted_output_dir, TRAINING_ARGS_NAME),
                        )
                        # save model config files
                        self.trainer.model.config.save_pretrained(
                            hf_converted_output_dir
                        )
                    except Exception as e:
                        raise ValueError(
                            f"Failed to convert the checkpoint {checkpoint_dir} "
                            "to an HF-compatible checkpoint"
                        ) from e

                # only the main process performs the conversion
                if state.is_world_process_zero:
                    checkpoint()

        callbacks.append(
            ConvertAndSaveHFCheckpointAtEverySave(
                pretrained_model_name_or_path, trainer
            )
        )
    return callbacks
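

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of this module): one way the callbacks
# returned by get_callbacks() above could be wired into a HF Trainer. The
# helper name and its arguments (model, tokenizer, train_dataset,
# base_model_path) are hypothetical placeholders supplied by the caller.
# ---------------------------------------------------------------------------
def _example_build_trainer(model, tokenizer, train_dataset, base_model_path):
    training_args = TrainingArguments(output_dir="./outputs", save_steps=500)
    # processing_class is used because the callback above saves
    # trainer.processing_class alongside the converted checkpoint.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        processing_class=tokenizer,
    )
    # Register the conversion callback; the list is empty when
    # fms_acceleration_moe is not installed.
    for callback in get_callbacks(
        pretrained_model_name_or_path=base_model_path,
        trainer=trainer,
    ):
        trainer.add_callback(callback)
    return trainer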