
Commit b895dc5

Disable KD mode from saving problematic state (#320)
Signed-off-by: Asha Anoosheh <[email protected]>
1 parent 27a6821 commit b895dc5

File tree

12 files changed (+103, -132 lines changed)

docs/source/guides/4_distillation.rst

Lines changed: 9 additions & 13 deletions
@@ -16,9 +16,9 @@ a more powerful teacher model using :mod:`modelopt.torch.distill <modelopt.torch
    interaction between the two.
 #. **Distillation training**: Seamlessly use the meta-model in place of the original model and run
    the original script with only one additional line of code for loss calculation.
-#. **Checkpoint and re-load**: Save the model via :meth:`mto.save <modelopt.torch.opt.conversion.save>` and
-   restore via :meth:`mto.restore <modelopt.torch.opt.conversion.restore>`. See :ref:`saving and restoring <save-restore>`
-   to learn more.
+#. **Checkpoint and re-load**: Save the model via :meth:`mto.save <modelopt.torch.opt.conversion.save>`.
+   Note that restoring the model (via :meth:`mto.restore <modelopt.torch.opt.conversion.restore>`)
+   will not reinstantiate the distillation meta-model, in order to avoid unpickling issues.
 
 *To find out more about Distillation and related concepts, please refer to the below section*
 :ref:`Distillation Concepts <distillation-concepts>`.
@@ -44,7 +44,7 @@ Example usage:
 
     # Configure and convert for distillation
     distillation_config = {
-        # `teacher_model` is a model class or callable, or a tuple.
+        # `teacher_model` is a model, model class, callable, or a tuple.
         # If a tuple, it must be of the form (model_cls_or_callable,) or
         # (model_cls_or_callable, args) or (model_cls_or_callable, args, kwargs).
         "teacher_model": teacher_model,
@@ -53,15 +53,9 @@ Example usage:
     }
     distillation_model = mtd.convert(model, mode=[("kd_loss", distillation_config)])
 
-    # Export model in original class form
+    # Export model in original class, with only previously-present attributes
     model_exported = mtd.export(distillation_model)
 
-.. note::
-    The config requires a (non-lambda) Callable to return a teacher model in place of the model
-    itself. This is to avoid re-saving the teacher state dict upon saving the Distillation
-    meta model. Thus, the same callable must be available in the namespace when restoring via
-    the :meth:`mto.restore <modelopt.torch.opt.conversion.restore>` utility.
-
 .. tip::
     When training the student on a small corpus of ground truth data, consider using :class:`MFTLoss <modelopt.torch.distill.MFTLoss>` to perform Minifinetuning in lieu of the standard
     :class:`LogitsDistillationLoss <modelopt.torch.distill.losses.LogitsDistillationLoss>`. This will allow the student to learn from the teacher's distribution while adapting to the new data, improving specialization on the new data without overwriting the teacher's general knowledge.
@@ -170,10 +164,12 @@ outputs in the same order as well:
 The intermediate outputs for the losses are captured by the
 :class:`DistillationModel <modelopt.torch.distill.distillation_model.DistillationModel>` and then the loss(es) are
 invoked using :meth:`DistillationModel.compute_kd_loss() <modelopt.torch.distill.distillation_model.DistillationModel.compute_kd_loss>`.
-If present, the original student's non-distillation loss is passed in as an argument.
+If present, the original student's non-distillation loss can be passed in as an argument.
 
 Writing a custom loss function is often necessary, especially to handle outputs that need to be processed
-to obtain the logits and activations.
+to obtain the logits and activations. Additional arguments to the loss function can be passed in to
+:meth:`DistillationModel.compute_kd_loss() <modelopt.torch.distill.distillation_model.DistillationModel.compute_kd_loss>`
+as ``kwargs``.
 
 Loss Balancer
 ^^^^^^^^^^^^^
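To make the loss wiring described above concrete, here is a minimal training-step sketch. It builds on the `distillation_model` from the docs example; the `student_loss` keyword, the `inputs`/`labels` placeholders, and the cross-entropy choice are illustrative assumptions, not code from this commit.

```python
import torch.nn.functional as F

# One training step with the distillation meta-model (sketch).
logits = distillation_model(inputs)             # student forward; teacher outputs are captured internally
student_loss = F.cross_entropy(logits, labels)  # optional original (non-distillation) loss
# Extra keyword arguments would be forwarded to the criterion as ``kwargs``.
loss = distillation_model.compute_kd_loss(student_loss=student_loss)
loss.backward()
```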

examples/llm_distill/README.md

Lines changed: 16 additions & 22 deletions
@@ -39,13 +39,9 @@ First obtain both a pretrained model to act as the teacher and a (usually smalle
 ```python
 from transformers import AutoModelForCausalLM
 
-# Define student
+# Define student & teacher
 student_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
-
-# Define callable which returns teacher
-def teacher_factory():
-    teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")
-    return teacher_model
+teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")
 ```
 
 ### Set up the meta model
@@ -58,15 +54,15 @@ Please see an example Distillation setup below. This example assumes the outputs
 import modelopt.torch.distill as mtd
 
 distillation_config = {
-    "teacher_model": teacher_factory,  # model initializer
+    "teacher_model": teacher_model,
     "criterion": mtd.LogitsDistillationLoss(),  # callable receiving student and teacher outputs, in order
     "loss_balancer": mtd.StaticLossBalancer(),  # combines multiple losses; omit if only one distillation loss used
 }
 
 distillation_model = mtd.convert(student_model, mode=[("kd_loss", distillation_config)])
 ```
 
-The `teacher_model` can be either a callable which returns an `nn.Module` or a tuple of `(model_cls, args, kwargs)`. The `criterion` is the distillation loss used between student and teacher tensors. The `loss_balancer` determines how the original and distillation losses are combined (if needed).
+The `teacher_model` can be an `nn.Module`, a callable which returns an `nn.Module`, or a tuple of `(model_cls, args, kwargs)`. The `criterion` is the distillation loss used between student and teacher tensors. The `loss_balancer` determines how the original and distillation losses are combined (if needed).
 
 See [Distillation](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/4_distillation.html) for more info.
 
@@ -158,35 +154,33 @@ Keep in mind the training loss of the distillation run is not directly comparabl
 ### Train teacher
 
 ```bash
-accelerate launch --multi_gpu --mixed_precision bf16 main.py \
+accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
+    main.py \
     --single_model \
     --teacher_name_or_path 'meta-llama/Llama-2-7b-hf' \
     --output_dir ./llama2-7b-sft \
-    --logging_steps 5 \
-    --max_steps 400 \
-    --max_seq_length 2048 \
+    --max_length 2048 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 4 \
-    --gradient_checkpointing True \
-    --fsdp 'full_shard auto_wrap' \
-    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer
+    --max_steps 400 \
+    --logging_steps 5
 ```
 
 ### Distill teacher into student
 
 ```bash
-accelerate launch --multi_gpu --mixed_precision bf16 main.py \
+accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
+    --fsdp_cpu_ram_efficient_loading False \
+    --fsdp_activation_checkpointing False \
+    main.py \
     --teacher_name_or_path ./llama2-7b-sft \
     --student_name_or_path 'TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T' \
     --output_dir ./llama2-distill \
-    --logging_steps 5 \
-    --max_steps 200 \
-    --max_seq_length 2048 \
+    --max_length 2048 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 4 \
-    --gradient_checkpointing False \
-    --fsdp 'full_shard auto_wrap' \
-    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer
+    --max_steps 200 \
+    --logging_steps 5
 ```
 
 > [!NOTE]
examples/llm_distill/accelerate_config/fsdp2.yaml

Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: FSDP
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+fsdp_config:
+  fsdp_activation_checkpointing: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_offload_params: false
+  fsdp_reshard_after_forward: true
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+  fsdp_version: 2
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: gpu
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false

examples/llm_distill/main.py

Lines changed: 14 additions & 28 deletions
@@ -21,7 +21,6 @@
 import torch
 import torch.distributed
 import transformers
-from accelerate import PartialState
 from accelerate.logging import get_logger
 from transformers import AutoTokenizer
 from trl import SFTTrainer
@@ -48,38 +47,28 @@ class TrainingArguments(transformers.TrainingArguments):
     do_train: bool = True
     do_eval: bool = True
     save_strategy: str = "no"
-    max_seq_length: int = 1024
+    max_length: int = 1024
     optim: str = "adamw_torch"
     learning_rate: float = 1e-5
     lr_scheduler_type: str = "cosine"
     dataloader_drop_last: bool = True
     dataset_num_proc: int = 8
-    dataset_batch_size: int = 500
     bf16: bool = True
     tf32: bool = True
 
 
 def llama_text_format_func(sample):
-    texts = []
-    for p, q, r in zip(sample["system_prompt"], sample["question"], sample["response"]):
-        if not p:
-            texts.append(f"<s>[INST] {q}[/INST]\n{r}</s>")
-        else:
-            texts.append(f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>")
-    return texts
+    p, q, r = sample["system_prompt"], sample["question"], sample["response"]
+    if not p:
+        return f"<s>[INST] {q}[/INST]\n{r}</s>"
+    else:
+        return f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>"
 
 
 class KDSFTTrainer(SFTTrainer, KDTrainer):
     pass
 
 
-def _teacher_factory(model_name_or_path):
-    return transformers.AutoModelForCausalLM.from_pretrained(
-        model_name_or_path,
-        device_map=PartialState().process_index,
-    )
-
-
 def train():
     parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
     model_args, training_args = parser.parse_args_into_dataclasses()
@@ -117,34 +106,31 @@ def train():
 
     if model_args.single_model:
         logger.info("Loading single model only...")
-        model = _teacher_factory(model_path)
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_path, dtype=torch.bfloat16 if training_args.bf16 else None
+        )
         logger.info("Model loaded.")
     else:
         logger.info("Loading student model...")
         model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.student_name_or_path,
-            device_map=PartialState().process_index,
+            model_args.student_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
         )
         logger.info("Student loaded.")
         # Load checkpoint
         logger.info("Loading teacher model and converting to Distillation model...")
+        teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
+        )
         kd_config = {
-            "teacher_model": (
-                _teacher_factory,
-                (model_args.teacher_name_or_path,),
-                {},
-            ),
+            "teacher_model": teacher_model,
             "criterion": LMLogitsLoss(),
-            "expose_minimal_state_dict": False,  # FSDP forces us to disable this
         }
         model = mtd.convert(model, mode=[("kd_loss", kd_config)])
         logger.info("Models converted.")
 
     # Fix problematic settings that cause excessive warnings
     model.generation_config.temperature = None
     model.generation_config.top_p = None
-    if training_args.gradient_checkpointing:
-        training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
 
     # Trainer
     trainer_cls = SFTTrainer if model_args.single_model else KDSFTTrainer
examples/llm_distill/requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
 pyarrow
-trl==0.13.0
+trl>=0.23.0

modelopt/torch/distill/config.py

Lines changed: 4 additions & 7 deletions
@@ -16,20 +16,18 @@
 """Configurations for distillation modes."""
 
 import warnings
-from collections.abc import Callable
 from typing import Any, Union
 
 import pydantic
-import torch.nn as nn
 from torch.nn.modules.loss import _Loss as Loss
 
 from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
+from modelopt.torch.utils.network import ModelLike
 
 from .loss_balancers import DistillationLossBalancer
 
 __all__ = ["KDLossConfig"]
 
-TeacherModel = type[nn.Module] | tuple | Callable
 Criterion = Union[Loss, dict[tuple[str, str], Loss]]  # noqa: UP007
 
 
@@ -42,14 +40,13 @@ class KDLossConfig(ModeloptBaseConfig):
     # TODO: we should really think about a better way to configure KDLossConfig
     model_config = pydantic.ConfigDict(extra="forbid", arbitrary_types_allowed=True)
 
-    teacher_model: TeacherModel | None = ModeloptField(
+    teacher_model: ModelLike | None = ModeloptField(
         default=None,
         title="Teacher model",
         description=(
-            "The class or callable or tuple to initialize the teacher model using"
+            "The module, class, callable, or tuple to initialize the teacher model using"
             " :meth:`init_model_from_model_like"
-            " <modelopt.torch.utils.network.init_model_from_model_like>`. This cannot already be an"
-            " instance of nn.Module."
+            " <modelopt.torch.utils.network.init_model_from_model_like>`."
        ),
    )
    criterion: Criterion | None = ModeloptField(
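Concretely, the switch from the old `TeacherModel` alias to `ModelLike` means an already-instantiated teacher is now accepted alongside the previously supported forms. A small sketch follows; `TeacherNet` and `build_teacher` are hypothetical stand-ins, not part of the library.

```python
import torch.nn as nn
import modelopt.torch.distill as mtd


class TeacherNet(nn.Module):  # hypothetical teacher architecture
    ...


def build_teacher() -> nn.Module:  # hypothetical factory
    return TeacherNet()


# All of these are valid "teacher_model" values for KDLossConfig now:
cfg_instance = {"teacher_model": TeacherNet(), "criterion": mtd.LogitsDistillationLoss()}
cfg_class = {"teacher_model": TeacherNet, "criterion": mtd.LogitsDistillationLoss()}
cfg_callable = {"teacher_model": build_teacher, "criterion": mtd.LogitsDistillationLoss()}
cfg_tuple = {"teacher_model": (TeacherNet, (), {}), "criterion": mtd.LogitsDistillationLoss()}
```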

modelopt/torch/distill/mode.py

Lines changed: 12 additions & 14 deletions
@@ -83,7 +83,12 @@ def restore(self) -> RestoreEntrypoint:
     @property
     def update_for_new_mode(self) -> UpdateEntrypoint:
         """The mode's entrypoint for updating the model's state when adding a new mode."""
-        return _update_kd_state_before_new_mode
+        return _reset_kd_state_config
+
+    @property
+    def update_for_save(self) -> UpdateEntrypoint:
+        """The mode's entrypoint for updating the model's state before saving."""
+        return _reset_kd_state_config
 
 
 @DistillModeRegistry.register_mode
@@ -171,16 +176,12 @@ def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType
 
 def _restore_kd_model(model: nn.Module, config: KDLossConfig, metadata: MetadataDict) -> nn.Module:
     """Function for restoring a previously converted model to a distillation meta-model."""
-    # the metadata should be empty
-    assert not metadata, "No metadata expected!"
+    # NOTE: DistillationModel will purposely remain unrestored
+    return model
 
-    return _convert_for_kd(model, config)[0]
 
-
-def _update_kd_state_before_new_mode(
-    model: nn.Module, config: KDLossConfig, metadata: MetadataDict
-) -> None:
-    """Function for updating the model's state before new mode."""
+def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict):
+    """Function for resetting the state's config."""
     config.teacher_model = nn.Module
     config.criterion = Loss()
     config.loss_balancer = None
@@ -216,8 +217,5 @@ def _export_student(model: nn.Module, config: ExportStudentConfig) -> ConvertRet
 def _restore_exported_student(
     model: nn.Module, config: ExportStudentConfig, metadata: MetadataDict
 ) -> nn.Module:
-    """Function for restoring a previously exported distillation meta-model."""
-    # no metadata is used by the mode
-    assert not metadata, "No metadata expected!"
-
-    return _export_student(model, config)[0]
+    # NOTE: DistillationModel was unrestored so this does nothing
+    return model
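The user-visible effect of this new restore path, as a hedged sketch using the public `mto` API: the checkpoint name is illustrative, `distillation_model` is a KD-converted model, and `fresh_student` stands for a newly constructed student module.

```python
import modelopt.torch.opt as mto

# Saving a KD-converted model: the stored ModelOpt state no longer embeds the
# live teacher or criterion objects, so it pickles cleanly.
mto.save(distillation_model, "kd_checkpoint.pth")

# Restoring onto a fresh student intentionally does NOT re-create the
# DistillationModel wrapper; the result is the plain student with its
# ModelOpt state re-applied.
restored_student = mto.restore(fresh_student, "kd_checkpoint.pth")
```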

modelopt/torch/distill/plugins/huggingface.py

Lines changed: 0 additions & 6 deletions
@@ -80,12 +80,6 @@ def save_model(
                 state_dict=state_dict,
             )
             self.processing_class.save_pretrained(output_dir)
-            if export_student:
-                modelopt_state["modelopt_state_dict"] = [
-                    state
-                    for state in modelopt_state["modelopt_state_dict"]
-                    if "kd_loss" not in state and "export_student" not in state
-                ]
             torch.save(modelopt_state, f"{output_dir}/modelopt_state.pth")
         else:
             model = model.export() if export_student else model

modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 0 additions & 6 deletions
@@ -174,12 +174,6 @@ def _save_modelopt_state_with_weights(self):
             torch.distributed.barrier()
 
         modelopt_state = mto.modelopt_state(self.model)
-        # TODO: remove this from ModelOpt HF Trainer flows
-        modelopt_state["modelopt_state_dict"] = [
-            state
-            for state in modelopt_state["modelopt_state_dict"]
-            if "kd_loss" not in state and "export_student" not in state
-        ]
         modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(self.model)
 
         if self.args.should_save:
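For reference, a condensed sketch of the save path both trainer plugins now share. Here `model` stands for the trainer's model, `output_dir` is illustrative, and the surrounding trainer plumbing is omitted; only `mto.modelopt_state` and `torch.save` appear in the hunks above.

```python
import torch
import modelopt.torch.opt as mto

output_dir = "./checkpoint"  # illustrative

# The full ModelOpt state is saved as-is: the kd_loss / export_student entries
# no longer need to be filtered out by hand, because the KD mode resets its own
# config before saving and its restore entrypoint is a no-op.
modelopt_state = mto.modelopt_state(model)
torch.save(modelopt_state, f"{output_dir}/modelopt_state.pth")
```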

tests/examples/llm_distill/test_llm_distill.py

Lines changed: 7 additions & 7 deletions
@@ -21,18 +21,18 @@
 def test_llama_distill(tiny_llama_path, tmp_path):
     run_example_command(
         [
-            "accelerate", "launch", "--multi_gpu", "--mixed_precision", "bf16", "main.py",
+            "accelerate", "launch", "--config-file", "./accelerate_config/fsdp2.yaml",
+            "--fsdp_cpu_ram_efficient_loading", "False",
+            "--fsdp_activation_checkpointing", "False",
+            "main.py",
             "--teacher_name_or_path", tiny_llama_path,
             "--student_name_or_path", tiny_llama_path,
             "--output_dir", tmp_path,
-            "--logging_steps", "5",
-            "--max_steps", "10",
-            "--max_seq_length", "1024",
+            "--max_length", "1024",
             "--per_device_train_batch_size", "2",
             "--per_device_eval_batch_size", "8",
-            "--gradient_checkpointing", "True",
-            "--fsdp", "full_shard auto_wrap",
-            "--fsdp_transformer_layer_cls_to_wrap", "LlamaDecoderLayer",
+            "--max_steps", "10",
+            "--logging_steps", "5",
         ],
         "llm_distill",
     )
