
Commit df0882a

Fix saving issue on new transformers w FSDP2 (for Release 0.37) (#414)
Signed-off-by: Asha Anoosheh <[email protected]>
1 parent 3a4ad73 · commit df0882a

2 files changed: 33 additions, 5 deletions


examples/llm_distill/main.py (32 additions, 5 deletions)

@@ -13,10 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
 import os
 from dataclasses import dataclass

+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
 import datasets
 import torch
 import torch.distributed
@@ -29,10 +30,7 @@
 import modelopt.torch.opt as mto
 from modelopt.torch.distill.plugins.huggingface import KDTrainer, LMLogitsLoss

-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
-
-logger = get_logger(__name__)
-logging.basicConfig(level=logging.INFO)
+logger = get_logger(__name__, log_level="INFO")


 @dataclass
@@ -69,6 +67,29 @@ class KDSFTTrainer(SFTTrainer, KDTrainer):
     pass


+def _save_model_fsdp_compat(
+    self,
+    output_dir: str | None = None,
+    _internal_call: bool = False,
+    *args,
+    **kwargs,
+):
+    output_dir = output_dir or self.args.output_dir
+    model = self.accelerator.unwrap_model(self.model)
+    if not _internal_call and self.is_fsdp_enabled:
+        state_dict = self.accelerator.get_state_dict(self.model)
+        if self.accelerator.is_main_process:
+            model.save_pretrained(
+                output_dir,
+                is_main_process=self.accelerator.is_main_process,
+                save_function=self.accelerator.save,
+                state_dict=state_dict,
+            )
+            self.processing_class.save_pretrained(output_dir)
+    else:
+        super(SFTTrainer, self).save_model(output_dir, _internal_call, *args, **kwargs)
+
+
 def train():
     parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
     model_args, training_args = parser.parse_args_into_dataclasses()
@@ -77,6 +98,9 @@ def train():
     # modelopt state will be saved automatically to "modelopt_state.pth"
     mto.enable_huggingface_checkpointing()

+    # HACK: Fix FSDP2-incompatible save_model() function for SFTTrainer
+    SFTTrainer.save_model = _save_model_fsdp_compat
+
     # Set total batch size across all ranks to equal 64
     total_batch_size = 64
     num_accum_steps = total_batch_size / (
@@ -91,19 +115,22 @@
         f"Using {int(num_accum_steps)} grad accumulation steps for effective batchsize of {total_batch_size}."
     )

+    # Dataset
     logger.info("Loading dataset...")
     dset = datasets.load_dataset("Open-Orca/OpenOrca", split="train")
     dset_splits = dset.train_test_split(train_size=25600, test_size=1700, seed=420)
     dset_train, dset_eval = dset_splits["train"], dset_splits["test"]
     logger.info("Dataset loaded.")

+    # Tokenizer
     logger.info("Loading tokenizer...")
     model_path = model_args.teacher_name_or_path or model_args.student_name_or_path
     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
     tokenizer.pad_token = tokenizer.eos_token
     tokenizer.padding_side = "right"
     logger.info("Tokenizer loaded.")

+    # Model
     if model_args.single_model:
         logger.info("Loading single model only...")
         model = transformers.AutoModelForCausalLM.from_pretrained(
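
For context, the new _save_model_fsdp_compat follows the pattern Accelerate documents for checkpointing FSDP-sharded models: gather the full state dict across ranks, then write it with save_pretrained on the main process only. Below is a minimal standalone sketch of that pattern, assuming a model already prepared under FSDP via accelerator.prepare(); the helper name save_fsdp_checkpoint and the model/tokenizer/output_dir variables are illustrative and not part of the commit.

# Sketch only: the general Accelerate recipe for saving an FSDP-sharded HF model.
# save_fsdp_checkpoint, model, tokenizer, and output_dir are illustrative names.
from accelerate import Accelerator


def save_fsdp_checkpoint(accelerator: Accelerator, model, tokenizer, output_dir: str):
    # Collect the full, unsharded state dict (gathered onto the main process under FSDP).
    state_dict = accelerator.get_state_dict(model)
    unwrapped = accelerator.unwrap_model(model)
    # Only the main process writes files; accelerator.save is passed as the save_function.
    unwrapped.save_pretrained(
        output_dir,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
        state_dict=state_dict,
    )
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)

In the patched method above, this route is taken only when self.is_fsdp_enabled is set and the call is not one of the Trainer's internal checkpoint saves; otherwise it defers to the stock Trainer.save_model via super(SFTTrainer, self).
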
Second changed file: 1 addition & 0 deletions

@@ -1,2 +1,3 @@
 pyarrow
+transformers<4.57
 trl>=0.23.0
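
The added pin keeps transformers below 4.57; per the commit title, newer transformers releases break saving with FSDP2 in this example. A small sketch for checking an installed environment against that pin, using importlib.metadata from the standard library and the packaging package; everything except the "<4.57" specifier is illustrative.

# Sketch: confirm the installed transformers version satisfies the new "<4.57" pin.
from importlib.metadata import version

from packaging.specifiers import SpecifierSet
from packaging.version import Version

pin = SpecifierSet("<4.57")
installed = Version(version("transformers"))
print(f"transformers {installed}:", "satisfies <4.57" if installed in pin else "violates the pin")
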
