# Export model in original class, with only previously-present attributes
model_exported = mtd.export(distillation_model)
.. tip::
    When training the student on a small corpus of ground truth data, consider using :class:`MFTLoss <modelopt.torch.distill.MFTLoss>` to perform Minifinetuning in lieu of the standard
    :class:`LogitsDistillationLoss <modelopt.torch.distill.losses.LogitsDistillationLoss>`. This allows the student to learn from the teacher's distribution while adapting to the new data, improving specialization on the new data without overwriting the teacher's general knowledge (a configuration sketch follows below).
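
For example, as a sketch only: ``kd_config`` is assumed to be the distillation configuration dict used elsewhere in this guide, and ``MFTLoss`` is assumed to be constructible with default arguments; check the API reference for its actual parameters.

.. code-block:: python

    import modelopt.torch.distill as mtd

    # Assumption: MFTLoss works with default constructor arguments here;
    # consult the API reference for its real parameters.
    kd_config["criterion"] = mtd.MFTLoss()
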
The intermediate outputs for the losses are captured by the
:class:`DistillationModel <modelopt.torch.distill.distillation_model.DistillationModel>` and then the loss(es) are
invoked using :meth:`DistillationModel.compute_kd_loss() <modelopt.torch.distill.distillation_model.DistillationModel.compute_kd_loss>`.
If present, the original student's non-distillation loss can be passed in as an argument.
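
As a rough sketch of how this looks in a training step (``batch``, ``labels``, and ``optimizer`` are placeholders, and the cross-entropy task loss is only an example):

.. code-block:: python

    import torch.nn.functional as F

    # Forward pass through the wrapped student; the DistillationModel also runs the
    # teacher and records the intermediate outputs needed for the distillation loss(es).
    logits = distillation_model(batch)

    # Original (non-distillation) task loss, if ground-truth labels are available.
    student_loss = F.cross_entropy(logits, labels)

    # Combine the task loss with the distillation loss(es) captured during the forward pass.
    loss = distillation_model.compute_kd_loss(student_loss=student_loss)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
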
Writing a custom loss function is often necessary, especially to handle outputs that need to be processed
to obtain the logits and activations. Additional arguments can be passed in to the loss function as well.
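
For illustration, a custom criterion might look like the following sketch. It assumes the criterion is called with the captured student and teacher outputs (in that order) and that the wrapped models return a dict containing a ``"logits"`` entry; both are assumptions made for this example rather than requirements stated by the library.

.. code-block:: python

    import torch
    import torch.nn.functional as F


    class DictOutputLogitsLoss(torch.nn.Module):
        """Hypothetical criterion for models whose forward() returns a dict with a "logits" entry."""

        def __init__(self, temperature: float = 1.0):
            super().__init__()
            self.temperature = temperature

        def forward(self, out_student, out_teacher):
            # Pull the logits out of the structured outputs before computing the KL divergence.
            student_logits = out_student["logits"] / self.temperature
            teacher_logits = out_teacher["logits"] / self.temperature
            return (
                F.kl_div(
                    F.log_softmax(student_logits, dim=-1),
                    F.softmax(teacher_logits, dim=-1),
                    reduction="batchmean",
                )
                * self.temperature**2
            )
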
The `teacher_model` can be an `nn.Module`, a callable which returns an `nn.Module`, or a tuple of `(model_cls, args, kwargs)`. The `criterion` is the distillation loss used between student and teacher tensors. The `loss_balancer` determines how the original and distillation losses are combined (if needed).
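
For example, a minimal configuration sketch might look like this (`teacher_factory` and `student_model` are placeholders, and the default constructor arguments shown are assumptions):

```python
import modelopt.torch.distill as mtd

kd_config = {
    "teacher_model": teacher_factory,  # nn.Module, callable returning one, or (model_cls, args, kwargs)
    "criterion": mtd.LogitsDistillationLoss(),  # distillation loss between student and teacher logits
    "loss_balancer": None,  # or a balancer that combines the original and distillation losses
}
distillation_model = mtd.convert(student_model, mode=[("kd_loss", kd_config)])
```
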
See [Distillation](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/4_distillation.html) for more info.