
Commit be7c7c0: Update KD docs
Signed-off-by: Asha Anoosheh <[email protected]>
1 parent: cf07926

3 files changed: 21 additions, 34 deletions


docs/source/guides/4_distillation.rst

Lines changed: 9 additions & 13 deletions
@@ -16,9 +16,9 @@ a more powerful teacher model using :mod:`modelopt.torch.distill <modelopt.torch
    interaction between the two.
 #. **Distillation training**: Seamlessly use the meta-model in place of the original model and run
    the original script with only one additional line of code for loss calculation.
-#. **Checkpoint and re-load**: Save the model via :meth:`mto.save <modelopt.torch.opt.conversion.save>` and
-   restore via :meth:`mto.restore <modelopt.torch.opt.conversion.restore>`. See :ref:`saving and restoring <save-restore>`
-   to learn more.
+#. **Checkpoint and re-load**: Save the model via :meth:`mto.save <modelopt.torch.opt.conversion.save>`.
+   Note that restoring the model (via :meth:`mto.restore <modelopt.torch.opt.conversion.restore>`)
+   will not reinstantiate the distillation meta-model, in order to avoid unpickling issues.
 
 *To find out more about Distillation and related concepts, please refer to the below section*
 :ref:`Distillation Concepts <distillation-concepts>`.
@@ -44,7 +44,7 @@ Example usage:
 
     # Configure and convert for distillation
    distillation_config = {
-        # `teacher_model` is a model class or callable, or a tuple.
+        # `teacher_model` is a model, model class, callable, or a tuple.
         # If a tuple, it must be of the form (model_cls_or_callable,) or
         # (model_cls_or_callable, args) or (model_cls_or_callable, args, kwargs).
         "teacher_model": teacher_model,
@@ -53,15 +53,9 @@ Example usage:
     }
     distillation_model = mtd.convert(model, mode=[("kd_loss", distillation_config)])
 
-    # Export model in original class form
+    # Export model in original class, with only previously-present attributes
     model_exported = mtd.export(distillation_model)
 
-.. note::
-    The config requires a (non-lambda) Callable to return a teacher model in place of the model
-    itself. This is to avoid re-saving the teacher state dict upon saving the Distillation
-    meta model. Thus, the same callable must be available in the namespace when restoring via
-    the :meth:`mto.restore <modelopt.torch.opt.conversion.restore>` utility.
-
 .. tip::
     When training the student on a small corpus of ground truth data, consider using :class:`MFTLoss <modelopt.torch.distill.MFTLoss>` to perform Minifinetuning in lieu of the standard
     :class:`LogitsDistillationLoss <modelopt.torch.distill.losses.LogitsDistillationLoss>`. This will allow the student to learn from the teacher's distribution while adapting to the new data, improving specialization on the new data without overwriting the teacher's general knowledge.
@@ -170,10 +164,12 @@ outputs in the same order as well:
 The intermediate outputs for the losses are captured by the
 :class:`DistillationModel <modelopt.torch.distill.distillation_model.DistillationModel>` and then the loss(es) are
 invoked using :meth:`DistillationModel.compute_kd_loss() <modelopt.torch.distill.distillation_model.DistillationModel.compute_kd_loss>`.
-If present, the original student's non-distillation loss is passed in as an argument.
+If present, the original student's non-distillation loss can be passed in as an argument.
 
 Writing a custom loss function is often necessary, especially to handle outputs that need to be processed
-to obtain the logits and activations.
+to obtain the logits and activations. Additional arguments to the loss function can be passed in to
+:meth:`DistillationModel.compute_kd_loss() <modelopt.torch.distill.distillation_model.DistillationModel.compute_kd_loss>`
+as ``kwargs``.
 
 Loss Balancer
 ^^^^^^^^^^^^^
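
To tie the updated guide text together, here is a minimal sketch of the documented flow (convert, compute the combined loss, save, export). It uses toy modules rather than real networks, and the `student_loss` keyword name for `compute_kd_loss` is an assumption on our part; treat it as an illustration of the API described above, not code from this commit.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

import modelopt.torch.distill as mtd
import modelopt.torch.opt as mto

# Toy stand-ins for real student/teacher networks (illustrative only).
student = nn.Linear(16, 10)
teacher = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))

distillation_config = {
    "teacher_model": teacher,  # a model instance now works directly
    "criterion": mtd.LogitsDistillationLoss(),  # student/teacher logits divergence
    "loss_balancer": mtd.StaticLossBalancer(),  # combines student loss with KD loss
}
distillation_model = mtd.convert(student, mode=[("kd_loss", distillation_config)])

x, y = torch.randn(4, 16), torch.randint(0, 10, (4,))
logits = distillation_model(x)  # teacher forward is captured behind the scenes
ce_loss = F.cross_entropy(logits, y)
loss = distillation_model.compute_kd_loss(student_loss=ce_loss)  # kwarg name assumed
loss.backward()

mto.save(distillation_model, "kd_checkpoint.pth")  # save via mto.save, per the updated docs
exported = mtd.export(distillation_model)  # back to the original class form
```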

examples/llm_distill/README.md

Lines changed: 4 additions & 8 deletions
@@ -39,13 +39,9 @@ First obtain both a pretrained model to act as the teacher and a (usually smaller
 ```python
 from transformers import AutoModelForCausalLM
 
-# Define student
+# Define student & teacher
 student_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
-
-# Define callable which returns teacher
-def teacher_factory():
-    teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")
-    return teacher_model
+teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")
 ```
 
 ### Set up the meta model
@@ -58,15 +54,15 @@ Please see an example Distillation setup below. This example assumes the outputs
 import modelopt.torch.distill as mtd
 
 distillation_config = {
-    "teacher_model": teacher_factory,  # model initializer
+    "teacher_model": teacher_model,
     "criterion": mtd.LogitsDistillationLoss(),  # callable receiving student and teacher outputs, in order
     "loss_balancer": mtd.StaticLossBalancer(),  # combines multiple losses; omit if only one distillation loss used
 }
 
 distillation_model = mtd.convert(student_model, mode=[("kd_loss", distillation_config)])
 ```
 
-The `teacher_model` can be either a callable which returns an `nn.Module` or a tuple of `(model_cls, args, kwargs)`. The `criterion` is the distillation loss used between student and teacher tensors. The `loss_balancer` determines how the original and distillation losses are combined (if needed).
+The `teacher_model` can be either an `nn.Module`, a callable which returns an `nn.Module`, or a tuple of `(model_cls, args, kwargs)`. The `criterion` is the distillation loss used between student and teacher tensors. The `loss_balancer` determines how the original and distillation losses are combined (if needed).
 
 See [Distillation](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/4_distillation.html) for more info.
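For completeness, here is a hypothetical sketch of the tuple form of `teacher_model` mentioned above, useful when the teacher should be constructed lazily rather than instantiated up front. The `torch_dtype` kwarg is purely illustrative, and `student_model` is assumed to be defined as in the README snippet:

```python
import torch
from transformers import AutoModelForCausalLM

import modelopt.torch.distill as mtd

# Tuple form: (model_cls_or_callable, args, kwargs); the teacher is built inside convert().
distillation_config = {
    "teacher_model": (
        AutoModelForCausalLM.from_pretrained,    # callable returning an nn.Module
        ("meta-llama/Llama-3.1-70B-Instruct",),  # positional args
        {"torch_dtype": torch.bfloat16},         # keyword args (illustrative)
    ),
    "criterion": mtd.LogitsDistillationLoss(),
}
distillation_model = mtd.convert(student_model, mode=[("kd_loss", distillation_config)])
```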
examples/llm_distill/main.py

Lines changed: 8 additions & 13 deletions
@@ -73,13 +73,6 @@ class KDSFTTrainer(SFTTrainer, KDTrainer):
     pass
 
 
-def _teacher_factory(model_name_or_path):
-    return transformers.AutoModelForCausalLM.from_pretrained(
-        model_name_or_path,
-        device_map=PartialState().process_index,
-    )
-
-
 def train():
     parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
     model_args, training_args = parser.parse_args_into_dataclasses()
@@ -117,7 +110,9 @@ def train():
 
     if model_args.single_model:
         logger.info("Loading single model only...")
-        model = _teacher_factory(model_path)
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_path, device_map=PartialState().process_index
+        )
         logger.info("Model loaded.")
     else:
         logger.info("Loading student model...")
@@ -128,12 +123,12 @@ def train():
         logger.info("Student loaded.")
         # Load checkpoint
         logger.info("Loading teacher model and converting to Distillation model...")
+        teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_args.teacher_name_or_path,
+            device_map=PartialState().process_index,
+        )
         kd_config = {
-            "teacher_model": (
-                _teacher_factory,
-                (model_args.teacher_name_or_path,),
-                {},
-            ),
+            "teacher_model": teacher_model,
             "criterion": LMLogitsLoss(),
             "expose_minimal_state_dict": False,  # FSDP forces us to disable this
         }
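
The `LMLogitsLoss` criterion in `main.py` is the kind of custom loss the guide refers to: Hugging Face model outputs have to be unpacked to reach the logits before a divergence can be computed. A rough, hypothetical sketch of such a criterion (not the actual `LMLogitsLoss` implementation) might look like this:

```python
# Illustrative only: a custom criterion that digs the logits out of Hugging Face
# CausalLM outputs before applying a temperature-scaled KL divergence.
import torch.nn as nn
import torch.nn.functional as F


class HFLogitsKLLoss(nn.Module):
    def __init__(self, temperature: float = 1.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, student_outputs, teacher_outputs):
        # The criterion receives student and teacher outputs, in that order.
        s = student_outputs.logits / self.temperature
        t = teacher_outputs.logits / self.temperature
        kl = F.kl_div(
            F.log_softmax(s, dim=-1),
            F.softmax(t, dim=-1),
            reduction="batchmean",
        )
        # Standard temperature scaling to keep gradient magnitudes comparable.
        return kl * self.temperature**2
```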
