 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 import os
 from dataclasses import dataclass
 
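+# Set before importing torch so the CUDA caching allocator picks up the config;
+# expandable segments help reduce memory fragmentation.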
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
 import datasets
 import torch
 import torch.distributed
@@ -29,10 +30,7 @@
 import modelopt.torch.opt as mto
 from modelopt.torch.distill.plugins.huggingface import KDTrainer, LMLogitsLoss
 
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
-
-logger = get_logger(__name__)
-logging.basicConfig(level=logging.INFO)
+logger = get_logger(__name__, log_level="INFO")
 
 
 @dataclass
@@ -69,6 +67,29 @@ class KDSFTTrainer(SFTTrainer, KDTrainer):
     pass
 
 
+def _save_model_fsdp_compat(
+    self,
+    output_dir: str | None = None,
+    _internal_call: bool = False,
+    *args,
+    **kwargs,
+):
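+    # Replacement for SFTTrainer.save_model(): when FSDP is enabled, gather the
+    # full (unsharded) state dict via Accelerate and save it, along with the
+    # tokenizer/processor, from the main process only; otherwise defer to the
+    # base Trainer.save_model().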
+    output_dir = output_dir or self.args.output_dir
+    model = self.accelerator.unwrap_model(self.model)
+    if not _internal_call and self.is_fsdp_enabled:
+        state_dict = self.accelerator.get_state_dict(self.model)
+        if self.accelerator.is_main_process:
+            model.save_pretrained(
+                output_dir,
+                is_main_process=self.accelerator.is_main_process,
+                save_function=self.accelerator.save,
+                state_dict=state_dict,
+            )
+            self.processing_class.save_pretrained(output_dir)
+    else:
+        super(SFTTrainer, self).save_model(output_dir, _internal_call, *args, **kwargs)
+
+
 def train():
     parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
     model_args, training_args = parser.parse_args_into_dataclasses()
@@ -77,6 +98,9 @@ def train():
     # modelopt state will be saved automatically to "modelopt_state.pth"
     mto.enable_huggingface_checkpointing()
 
+    # HACK: Fix FSDP2-incompatible save_model() function for SFTTrainer
+    SFTTrainer.save_model = _save_model_fsdp_compat
+
     # Set total batch size across all ranks to equal 64
     total_batch_size = 64
     num_accum_steps = total_batch_size / (
@@ -91,19 +115,22 @@ def train():
         f"Using {int(num_accum_steps)} grad accumulation steps for effective batchsize of {total_batch_size}."
     )
 
+    # Dataset
     logger.info("Loading dataset...")
     dset = datasets.load_dataset("Open-Orca/OpenOrca", split="train")
     dset_splits = dset.train_test_split(train_size=25600, test_size=1700, seed=420)
     dset_train, dset_eval = dset_splits["train"], dset_splits["test"]
     logger.info("Dataset loaded.")
 
+    # Tokenizer
     logger.info("Loading tokenizer...")
     model_path = model_args.teacher_name_or_path or model_args.student_name_or_path
     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
     tokenizer.pad_token = tokenizer.eos_token
     tokenizer.padding_side = "right"
     logger.info("Tokenizer loaded.")
 
+    # Model
     if model_args.single_model:
         logger.info("Loading single model only...")
         model = transformers.AutoModelForCausalLM.from_pretrained(