Commit 35f90d0

Revamp distillation HF example (#430)
Signed-off-by: Asha Anoosheh <[email protected]>
1 parent 46a9e49 commit 35f90d0

File tree: 4 files changed, +41 -102 lines changed

examples/llm_distill/README.md

Lines changed: 10 additions & 35 deletions
@@ -49,8 +49,8 @@ First obtain both a pretrained model to act as the teacher and a (usually smalle
 from transformers import AutoModelForCausalLM

 # Define student & teacher
-student_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
-teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")
+student_model = AutoModelForCausalLM.from_pretrained("student-model-id-or-path")
+teacher_model = AutoModelForCausalLM.from_pretrained("teacher-model-id-or-path")
 ```

 ### Set up the meta model
@@ -149,52 +149,27 @@ You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDI

 ## Knowledge Distillation (KD) for HuggingFace Models

-In this e2e example we finetune Llama-2 models on the [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca)
-question-answer dataset as a minimal example to demonstrate a simple way of integrating Model Optimizer's KD feature.
+In this e2e example we finetune Llama-3.2 models on the [smol-smoltalk-Interaction-SFT](https://huggingface.co/datasets/ReactiveAI/smol-smoltalk-Interaction-SFT)
+dataset as a minimal example to demonstrate a simple way of integrating Model Optimizer's KD feature.

-First we do supervised finetuning (SFT) of a Llama-2-7b on OpenOrca dataset as the teacher, then distill it into
-a 1B-parameter model.
-
-Keep in mind the training loss of the distillation run is not directly comparable to the training loss of the teacher run.
+We replace normal supervised finetuning (SFT) of a Llama-3.2-1B base model by distilling information from Llama-3.2-3B-Instruct which has already been instruction-finetuned.

 > [!NOTE]
 > We can fit the following in memory using [FSDP](https://huggingface.co/docs/accelerate/en/usage_guides/fsdp) enabled on 8x RTX 6000 (total ~400GB VRAM)

-### Train teacher
-
-```bash
-accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
-    main.py \
-    --single_model \
-    --teacher_name_or_path 'meta-llama/Llama-2-7b-hf' \
-    --output_dir ./llama2-7b-sft \
-    --max_length 2048 \
-    --per_device_train_batch_size 1 \
-    --per_device_eval_batch_size 4 \
-    --max_steps 400 \
-    --logging_steps 5
-```
-
-### Distill teacher into student
-
 ```bash
 accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
-    --fsdp_cpu_ram_efficient_loading False \
-    --fsdp_activation_checkpointing False \
     main.py \
-    --teacher_name_or_path ./llama2-7b-sft \
-    --student_name_or_path 'TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T' \
-    --output_dir ./llama2-distill \
+    --teacher_name_or_path 'meta-llama/Llama-3.2-3B-Instruct' \
+    --student_name_or_path 'meta-llama/Llama-3.2-1B' \
+    --output_dir ./llama3.2-distill \
     --max_length 2048 \
-    --per_device_train_batch_size 1 \
-    --per_device_eval_batch_size 4 \
+    --per_device_train_batch_size 4 \
+    --per_device_eval_batch_size 8 \
     --max_steps 200 \
     --logging_steps 5
 ```

-> [!NOTE]
-> If you receive a `RuntimeError: unable to open file <...> in read-only mode: No such file or directory` simply re-run the command a second time.
-
 ## Resources

 - 📅 [Roadmap](https://github.com/NVIDIA/TensorRT-Model-Optimizer/issues/146)
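To make the revised README flow concrete, here is a minimal sketch of what the single distillation command above sets up under the hood, mirroring the `main.py` changes later in this commit. The import location of `LMLogitsLoss` is an assumption based on ModelOpt's HuggingFace plugin layout; check it against the example's actual imports.

```python
# Minimal sketch of the KD setup driven by the README command above
# (not the exact example code; see main.py below for the authoritative version).
import modelopt.torch.distill as mtd
import modelopt.torch.opt as mto
import transformers

# Assumed import location for the logits-distillation criterion used in main.py.
from modelopt.torch.distill.plugins.huggingface import LMLogitsLoss

mto.enable_huggingface_checkpointing()  # persist ModelOpt state alongside HF checkpoints

# Same teacher/student pair as the README command; swap in your own models as needed.
student = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
teacher = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

# Wrap the student into a distillation "meta model" that carries the teacher and loss.
kd_config = {"teacher_model": teacher, "criterion": LMLogitsLoss()}
model = mtd.convert(student, mode=[("kd_loss", kd_config)])
```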

examples/llm_distill/accelerate_config/fsdp2.yaml

Lines changed: 2 additions & 2 deletions
@@ -4,9 +4,9 @@ distributed_type: FSDP
 downcast_bf16: 'no'
 enable_cpu_affinity: false
 fsdp_config:
-  fsdp_activation_checkpointing: true
+  fsdp_activation_checkpointing: false
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  fsdp_cpu_ram_efficient_loading: true
+  fsdp_cpu_ram_efficient_loading: false
   fsdp_offload_params: false
   fsdp_reshard_after_forward: true
   fsdp_state_dict_type: SHARDED_STATE_DICT

examples/llm_distill/main.py

Lines changed: 29 additions & 63 deletions
@@ -37,7 +37,6 @@
 class ModelArguments:
     teacher_name_or_path: str | None = None
     student_name_or_path: str | None = None
-    single_model: bool = False


 @dataclass
@@ -55,41 +54,20 @@ class TrainingArguments(transformers.TrainingArguments):
     tf32: bool = True


-def llama_text_format_func(sample):
-    p, q, r = sample["system_prompt"], sample["question"], sample["response"]
-    if not p:
-        return f"<s>[INST] {q}[/INST]\n{r}</s>"
-    else:
-        return f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>"
+def _format_smoltalk_chat_template(sample, tokenizer):
+    # smol-smoltalk-Interaction-SFT dataset has "query" and "answer" fields
+    # Convert them to messages format and use tokenizer's apply_chat_template
+    messages = [
+        {"role": "user", "content": sample["query"]},
+        {"role": "assistant", "content": sample["answer"]},
+    ]
+    return tokenizer.apply_chat_template(messages, tokenize=False)


 class KDSFTTrainer(SFTTrainer, KDTrainer):
     pass


-def _save_model_fsdp_compat(
-    self,
-    output_dir: str | None = None,
-    _internal_call: bool = False,
-    *args,
-    **kwargs,
-):
-    output_dir = output_dir or self.args.output_dir
-    model = self.accelerator.unwrap_model(self.model)
-    if not _internal_call and self.is_fsdp_enabled:
-        state_dict = self.accelerator.get_state_dict(self.model)
-        if self.accelerator.is_main_process:
-            model.save_pretrained(
-                output_dir,
-                is_main_process=self.accelerator.is_main_process,
-                save_function=self.accelerator.save,
-                state_dict=state_dict,
-            )
-            self.processing_class.save_pretrained(output_dir)
-    else:
-        super(SFTTrainer, self).save_model(output_dir, _internal_call, *args, **kwargs)
-
-
 def train():
     parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
     model_args, training_args = parser.parse_args_into_dataclasses()
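For reference, a quick sketch of how the new `_format_smoltalk_chat_template` helper behaves. The sample below is made up, and any tokenizer that ships a chat template can stand in for the Llama-3.2 one (which is gated on the Hub):

```python
from transformers import AutoTokenizer

# Any chat-template-equipped tokenizer works; the Llama-3.2 repos are gated on the Hub.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

# Hypothetical sample in the smol-smoltalk-Interaction-SFT "query"/"answer" format.
sample = {
    "query": "What is knowledge distillation?",
    "answer": "Training a smaller student model to match a larger teacher's outputs.",
}

messages = [
    {"role": "user", "content": sample["query"]},
    {"role": "assistant", "content": sample["answer"]},
]
# tokenize=False returns the fully formatted chat string instead of token ids,
# which is the kind of output SFTTrainer's formatting_func is expected to produce.
print(tokenizer.apply_chat_template(messages, tokenize=False))
```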
@@ -98,9 +76,6 @@ def train():
     # modelopt state will be saved automatically to "modelopt_state.pth"
     mto.enable_huggingface_checkpointing()

-    # HACK: Fix FSDP2-incompatible save_model() function for SFTTrainer
-    SFTTrainer.save_model = _save_model_fsdp_compat
-
     # Set total batch size across all ranks to equal 64
     total_batch_size = 64
     num_accum_steps = total_batch_size / (
@@ -117,8 +92,8 @@

     # Dataset
     logger.info("Loading dataset...")
-    dset = datasets.load_dataset("Open-Orca/OpenOrca", split="train")
-    dset_splits = dset.train_test_split(train_size=25600, test_size=1700, seed=420)
+    dset = datasets.load_dataset("ReactiveAI/smol-smoltalk-Interaction-SFT", split="train")
+    dset_splits = dset.train_test_split(train_size=12800, test_size=1280, seed=420)
     dset_train, dset_eval = dset_splits["train"], dset_splits["test"]
     logger.info("Dataset loaded.")

@@ -131,42 +106,34 @@
     logger.info("Tokenizer loaded.")

     # Model
-    if model_args.single_model:
-        logger.info("Loading single model only...")
-        model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_path, dtype=torch.bfloat16 if training_args.bf16 else None
-        )
-        logger.info("Model loaded.")
-    else:
-        logger.info("Loading student model...")
-        model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.student_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
-        )
-        logger.info("Student loaded.")
-        # Load checkpoint
-        logger.info("Loading teacher model and converting to Distillation model...")
-        teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
-        )
-        kd_config = {
-            "teacher_model": teacher_model,
-            "criterion": LMLogitsLoss(),
-        }
-        model = mtd.convert(model, mode=[("kd_loss", kd_config)])
-        logger.info("Models converted.")
+    logger.info("Loading student model...")
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_args.student_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
+    )
+    logger.info("Student loaded.")
+    # Load checkpoint
+    logger.info("Loading teacher model and converting to Distillation model...")
+    teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
+    )
+    kd_config = {
+        "teacher_model": teacher_model,
+        "criterion": LMLogitsLoss(),
+    }
+    model = mtd.convert(model, mode=[("kd_loss", kd_config)])
+    logger.info("Models converted.")

     # Fix problematic settings that logger.info excessive warnings
     model.generation_config.temperature = None
     model.generation_config.top_p = None

     # Trainer
-    trainer_cls = SFTTrainer if model_args.single_model else KDSFTTrainer
-    trainer = trainer_cls(
+    trainer = KDSFTTrainer(
         model,
         training_args,
         train_dataset=dset_train,
         eval_dataset=dset_eval,
-        formatting_func=llama_text_format_func,
+        formatting_func=lambda sample: _format_smoltalk_chat_template(sample, tokenizer),
         processing_class=tokenizer,
     )

@@ -186,8 +153,7 @@ def train():
     # Save checkpoint
     logger.info("Saving checkpoint...")
     trainer.save_state()
-    kwargs = {"export_student": True} if not model_args.single_model else {}
-    trainer.save_model(trainer.args.output_dir, **kwargs)
+    trainer.save_model(trainer.args.output_dir, export_student=True)
    logger.info("Checkpoint saved.")

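With the single-model path removed, the example now always exports the student at the end (`export_student=True`). A hedged sketch of loading that exported checkpoint back, assuming the output directory from the README command and that the export produces a standard HuggingFace checkpoint:

```python
import modelopt.torch.opt as mto
from transformers import AutoModelForCausalLM, AutoTokenizer

# Restores any ModelOpt state saved alongside the checkpoint, if present.
mto.enable_huggingface_checkpointing()

# Path assumed from the README command above (--output_dir ./llama3.2-distill).
output_dir = "./llama3.2-distill"
student = AutoModelForCausalLM.from_pretrained(output_dir)
# The trainer normally saves its processing_class (tokenizer) next to the model;
# fall back to the original student tokenizer if it is missing.
tokenizer = AutoTokenizer.from_pretrained(output_dir)
```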

tests/examples/llm_distill/test_llm_distill.py

Lines changed: 0 additions & 2 deletions
@@ -22,8 +22,6 @@ def test_llama_distill(tiny_llama_path, tmp_path):
     run_example_command(
         [
             "accelerate", "launch", "--config-file", "./accelerate_config/fsdp2.yaml",
-            "--fsdp_cpu_ram_efficient_loading", "False",
-            "--fsdp_activation_checkpointing", "False",
             "main.py",
             "--teacher_name_or_path", tiny_llama_path,
             "--student_name_or_path", tiny_llama_path,
