22 changes: 9 additions & 13 deletions docs/source/guides/4_distillation.rst
@@ -16,9 +16,9 @@ a more powerful teacher model using :mod:`modelopt.torch.distill <modelopt.torch
interaction between the two.
#. **Distillation training**: Seamlessly use the meta-model in place of the original model and run
the original script with only one additional line of code for loss calculation.
#. **Checkpoint and re-load**: Save the model via :meth:`mto.save <modelopt.torch.opt.conversion.save>` and
restore via :meth:`mto.restore <modelopt.torch.opt.conversion.restore>`. See :ref:`saving and restoring <save-restore>`
to learn more.
#. **Checkpoint and re-load**: Save the model via :meth:`mto.save <modelopt.torch.opt.conversion.save>`.
Note that restoring the model (via :meth:`mto.restore <modelopt.torch.opt.conversion.restore>`)
will not reinstantiate the distillation meta-model, in order to avoid unpickling issues.

*To find out more about Distillation and related concepts, please refer to the below section*
:ref:`Distillation Concepts <distillation-concepts>`.
@@ -44,7 +44,7 @@ Example usage:

# Configure and convert for distillation
distillation_config = {
# `teacher_model` is a model class or callable, or a tuple.
# `teacher_model` is a model, model class, callable, or a tuple.
# If a tuple, it must be of the form (model_cls_or_callable,) or
# (model_cls_or_callable, args) or (model_cls_or_callable, args, kwargs).
"teacher_model": teacher_model,
@@ -53,15 +53,9 @@ Example usage:
}
distillation_model = mtd.convert(model, mode=[("kd_loss", distillation_config)])

# Export model in original class form
# Export model in original class, with only previously-present attributes
model_exported = mtd.export(distillation_model)
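
Checkpointing follows the usual ModelOpt flow described above. A minimal sketch (the
checkpoint path and ``fresh_student`` below are placeholders for your own objects):

.. code-block:: python

    import modelopt.torch.opt as mto

    # Save the distillation model together with its modelopt state
    mto.save(distillation_model, "distill_checkpoint.pth")

    # Restore onto a freshly constructed student instance. Restoring intentionally
    # does not re-instantiate the DistillationModel wrapper, so the result is the
    # plain student carrying its restored modelopt state.
    fresh_student = build_student()  # placeholder for re-creating the original model
    restored = mto.restore(fresh_student, "distill_checkpoint.pth")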

.. note::
The config requires a (non-lambda) Callable to return a teacher model in place of the model
itself. This is to avoid re-saving the teacher state dict upon saving the Distillation
meta model. Thus, the same callable must be available in the namespace when restoring via
the :meth:`mto.restore <modelopt.torch.opt.conversion.restore>` utility.

.. tip::
When training the student on a small corpus of ground-truth data, consider using :class:`MFTLoss <modelopt.torch.distill.MFTLoss>` to perform Minifinetuning in lieu of the standard
:class:`LogitsDistillationLoss <modelopt.torch.distill.losses.LogitsDistillationLoss>`. This allows the student to learn from the teacher's distribution while adapting to the new data, improving specialization on the new data without overwriting the teacher's general knowledge.
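
A rough sketch of that swap is below; the ``MFTLoss`` constructor arguments are omitted
here as an assumption, so check its API reference for the required parameters:

.. code-block:: python

    distillation_config = {
        "teacher_model": teacher_model,
        # Assumption: MFTLoss used as a drop-in replacement for LogitsDistillationLoss;
        # verify its constructor arguments against the API reference.
        "criterion": mtd.MFTLoss(),
    }
    distillation_model = mtd.convert(model, mode=[("kd_loss", distillation_config)])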
@@ -170,10 +164,12 @@ outputs in the same order as well:
The intermediate outputs for the losses are captured by the
:class:`DistillationModel <modelopt.torch.distill.distillation_model.DistillationModel>` and then the loss(es) are
invoked using :meth:`DistillationModel.compute_kd_loss() <modelopt.torch.distill.distillation_model.DistillationModel.compute_kd_loss>`.
If present, the original student's non-distillation loss is passed in as an argument.
If present, the original student's non-distillation loss can be passed in as an argument.

Writing a custom loss function is often necessary, especially to handle outputs that need to be processed
to obtain the logits and activations.
to obtain the logits and activations. Additional arguments to the loss function can be passed in to
:meth:`DistillationModel.compute_kd_loss() <modelopt.torch.distill.distillation_model.DistillationModel.compute_kd_loss>`
as ``kwargs``.
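
A hedged training-step sketch is below; ``inputs``, ``labels``, ``task_loss_fn``, and the
``student_loss`` keyword are placeholders and assumptions to verify against the
:meth:`compute_kd_loss` signature:

.. code-block:: python

    # One training step with the distillation meta-model
    outputs = distillation_model(inputs)          # forward pass runs student and teacher
    task_loss = task_loss_fn(outputs, labels)     # optional student-only (non-distillation) loss
    loss = distillation_model.compute_kd_loss(student_loss=task_loss)
    loss.backward()
    optimizer.step()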

Loss Balancer
^^^^^^^^^^^^^
38 changes: 16 additions & 22 deletions examples/llm_distill/README.md
@@ -39,13 +39,9 @@ First obtain both a pretrained model to act as the teacher and a (usually smalle
```python
from transformers import AutoModelForCausalLM

# Define student
# Define student & teacher
student_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# Define callable which returns teacher
def teacher_factory():
teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")
return teacher_model
teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")
```

### Set up the meta model
@@ -58,15 +54,15 @@ Please see an example Distillation setup below. This example assumes the outputs
import modelopt.torch.distill as mtd

distillation_config = {
"teacher_model": teacher_factory, # model initializer
"teacher_model": teacher_model,
"criterion": mtd.LogitsDistillationLoss(), # callable receiving student and teacher outputs, in order
"loss_balancer": mtd.StaticLossBalancer(), # combines multiple losses; omit if only one distillation loss used
}

distillation_model = mtd.convert(student_model, mode=[("kd_loss", distillation_config)])
```

The `teacher_model` can be either a callable which returns an `nn.Module` or a tuple of `(model_cls, args, kwargs)`. The `criterion` is the distillation loss used between student and teacher tensors. The `loss_balancer` determines how the original and distillation losses are combined (if needed).
The `teacher_model` can be an `nn.Module`, a callable that returns an `nn.Module`, or a tuple of `(model_cls, args, kwargs)`. The `criterion` is the distillation loss used between student and teacher tensors. The `loss_balancer` determines how the original and distillation losses are combined (if needed).
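
For illustration, the three accepted forms might look roughly like this (a sketch; the factory function and tuple below are hypothetical and simply mirror the forms listed above):

```python
from transformers import AutoModelForCausalLM

# 1) An already-instantiated module
teacher_spec = teacher_model

# 2) A callable that returns an nn.Module
def make_teacher():
    return AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")

teacher_spec = make_teacher

# 3) A (model_cls_or_callable, args, kwargs) tuple
teacher_spec = (
    AutoModelForCausalLM.from_pretrained,
    ("meta-llama/Llama-3.1-70B-Instruct",),
    {},
)

distillation_config["teacher_model"] = teacher_spec
```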

See [Distillation](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/4_distillation.html) for more info.

@@ -158,35 +154,33 @@ Keep in mind the training loss of the distillation run is not directly comparabl
### Train teacher

```bash
accelerate launch --multi_gpu --mixed_precision bf16 main.py \
accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
main.py \
--single_model \
--teacher_name_or_path 'meta-llama/Llama-2-7b-hf' \
--output_dir ./llama2-7b-sft \
--logging_steps 5 \
--max_steps 400 \
--max_seq_length 2048 \
--max_length 2048 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 4 \
--gradient_checkpointing True \
--fsdp 'full_shard auto_wrap' \
--fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer
--max_steps 400 \
--logging_steps 5
```

### Distill teacher into student

```bash
accelerate launch --multi_gpu --mixed_precision bf16 main.py \
accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
--fsdp_cpu_ram_efficient_loading False \
--fsdp_activation_checkpointing False \
main.py \
--teacher_name_or_path ./llama2-7b-sft \
--student_name_or_path 'TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T' \
--output_dir ./llama2-distill \
--logging_steps 5 \
--max_steps 200 \
--max_seq_length 2048 \
--max_length 2048 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 4 \
--gradient_checkpointing False \
--fsdp 'full_shard auto_wrap' \
--fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer
--max_steps 200 \
--logging_steps 5
```

> [!NOTE]
25 changes: 25 additions & 0 deletions examples/llm_distill/accelerate_config/fsdp2.yaml
@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: gpu
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
42 changes: 14 additions & 28 deletions examples/llm_distill/main.py
@@ -21,7 +21,6 @@
import torch
import torch.distributed
import transformers
from accelerate import PartialState
from accelerate.logging import get_logger
from transformers import AutoTokenizer
from trl import SFTTrainer
@@ -48,38 +47,28 @@ class TrainingArguments(transformers.TrainingArguments):
do_train: bool = True
do_eval: bool = True
save_strategy: str = "no"
max_seq_length: int = 1024
max_length: int = 1024
optim: str = "adamw_torch"
learning_rate: float = 1e-5
lr_scheduler_type: str = "cosine"
dataloader_drop_last: bool = True
dataset_num_proc: int = 8
dataset_batch_size: int = 500
bf16: bool = True
tf32: bool = True


def llama_text_format_func(sample):
texts = []
for p, q, r in zip(sample["system_prompt"], sample["question"], sample["response"]):
if not p:
texts.append(f"<s>[INST] {q}[/INST]\n{r}</s>")
else:
texts.append(f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>")
return texts
p, q, r = sample["system_prompt"], sample["question"], sample["response"]
if not p:
return f"<s>[INST] {q}[/INST]\n{r}</s>"
else:
return f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>"


class KDSFTTrainer(SFTTrainer, KDTrainer):
pass


def _teacher_factory(model_name_or_path):
return transformers.AutoModelForCausalLM.from_pretrained(
model_name_or_path,
device_map=PartialState().process_index,
)


def train():
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
model_args, training_args = parser.parse_args_into_dataclasses()
@@ -117,34 +106,31 @@ def train():

if model_args.single_model:
logger.info("Loading single model only...")
model = _teacher_factory(model_path)
model = transformers.AutoModelForCausalLM.from_pretrained(
model_path, dtype=torch.bfloat16 if training_args.bf16 else None
)
logger.info("Model loaded.")
else:
logger.info("Loading student model...")
model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.student_name_or_path,
device_map=PartialState().process_index,
model_args.student_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
)
logger.info("Student loaded.")
# Load checkpoint
logger.info("Loading teacher model and converting to Distillation model...")
teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
)
Comment on lines +121 to +123
⚠️ Potential issue

Same kwarg bug for teacher load: use torch_dtype

Keep teacher/student loads consistent.

-        teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
-        )
+        teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_args.teacher_name_or_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
+        )

kd_config = {
"teacher_model": (
_teacher_factory,
(model_args.teacher_name_or_path,),
{},
),
"teacher_model": teacher_model,
"criterion": LMLogitsLoss(),
"expose_minimal_state_dict": False, # FSDP forces us to disable this
}
model = mtd.convert(model, mode=[("kd_loss", kd_config)])
logger.info("Models converted.")

# Fix problematic settings that cause excessive warnings
model.generation_config.temperature = None
model.generation_config.top_p = None
if training_args.gradient_checkpointing:
training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}

# Trainer
trainer_cls = SFTTrainer if model_args.single_model else KDSFTTrainer
2 changes: 1 addition & 1 deletion examples/llm_distill/requirements.txt
@@ -1,2 +1,2 @@
pyarrow
trl==0.13.0
trl>=0.23.0
11 changes: 4 additions & 7 deletions modelopt/torch/distill/config.py
@@ -16,20 +16,18 @@
"""Configurations for distillation modes."""

import warnings
from collections.abc import Callable
from typing import Any, Union

import pydantic
import torch.nn as nn
from torch.nn.modules.loss import _Loss as Loss

from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
from modelopt.torch.utils.network import ModelLike

from .loss_balancers import DistillationLossBalancer

__all__ = ["KDLossConfig"]

TeacherModel = type[nn.Module] | tuple | Callable
Criterion = Union[Loss, dict[tuple[str, str], Loss]] # noqa: UP007


@@ -42,14 +40,13 @@ class KDLossConfig(ModeloptBaseConfig):
# TODO: we should really think about a better to configure KDLossConfig
model_config = pydantic.ConfigDict(extra="forbid", arbitrary_types_allowed=True)

teacher_model: TeacherModel | None = ModeloptField(
teacher_model: ModelLike | None = ModeloptField(
default=None,
title="Teacher model",
description=(
"The class or callable or tuple to initialize the teacher model using"
"The module, class, callable, or tuple to initialize the teacher model using"
" :meth:`init_model_from_model_like"
" <modelopt.torch.utils.network.init_model_from_model_like>`. This cannot already be an"
" instance of nn.Module."
" <modelopt.torch.utils.network.init_model_from_model_like>`."
),
)
criterion: Criterion | None = ModeloptField(
26 changes: 12 additions & 14 deletions modelopt/torch/distill/mode.py
@@ -83,7 +83,12 @@ def restore(self) -> RestoreEntrypoint:
@property
def update_for_new_mode(self) -> UpdateEntrypoint:
"""The mode's entrypoint for updating the models state for adding new mode."""
return _update_kd_state_before_new_mode
return _reset_kd_state_config

@property
def update_for_save(self) -> UpdateEntrypoint:
"""The mode's entrypoint for updating the models state before saving."""
return _reset_kd_state_config


@DistillModeRegistry.register_mode
@@ -171,16 +176,12 @@ def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType

def _restore_kd_model(model: nn.Module, config: KDLossConfig, metadata: MetadataDict) -> nn.Module:
"""Function for restoring a previously convert model to a distillation meta-model."""
# the metadata should be empty
assert not metadata, "No metadata expected!"
# NOTE: DistillationModel will purposely remain unrestored
return model

return _convert_for_kd(model, config)[0]


def _update_kd_state_before_new_mode(
model: nn.Module, config: KDLossConfig, metadata: MetadataDict
) -> None:
"""Function for updating the model's state before new mode."""
def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict):
"""Function for resetting the state's config."""
config.teacher_model = nn.Module
config.criterion = Loss()
config.loss_balancer = None
@@ -216,8 +217,5 @@ def _export_student(model: nn.Module, config: ExportStudentConfig) -> ConvertRet
def _restore_exported_student(
model: nn.Module, config: ExportStudentConfig, metadata: MetadataDict
) -> nn.Module:
"""Function for restoring a previously exported distillation meta-model."""
# no metadata is used by the mode
assert not metadata, "No metadata expected!"

return _export_student(model, config)[0]
# NOTE: DistillationModel was unrestored so this does nothing
return model
6 changes: 0 additions & 6 deletions modelopt/torch/distill/plugins/huggingface.py
@@ -80,12 +80,6 @@ def save_model(
state_dict=state_dict,
)
self.processing_class.save_pretrained(output_dir)
if export_student:
modelopt_state["modelopt_state_dict"] = [
state
for state in modelopt_state["modelopt_state_dict"]
if "kd_loss" not in state and "export_student" not in state
]
torch.save(modelopt_state, f"{output_dir}/modelopt_state.pth")
else:
model = model.export() if export_student else model
6 changes: 0 additions & 6 deletions modelopt/torch/quantization/plugins/transformers_trainer.py
@@ -174,12 +174,6 @@ def _save_modelopt_state_with_weights(self):
torch.distributed.barrier()

modelopt_state = mto.modelopt_state(self.model)
# TODO: remove this from ModelOpt HF Trainer flows
modelopt_state["modelopt_state_dict"] = [
state
for state in modelopt_state["modelopt_state_dict"]
if "kd_loss" not in state and "export_student" not in state
]
modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(self.model)

if self.args.should_save:
14 changes: 7 additions & 7 deletions tests/examples/llm_distill/test_llm_distill.py
@@ -21,18 +21,18 @@
def test_llama_distill(tiny_llama_path, tmp_path):
run_example_command(
[
"accelerate", "launch", "--multi_gpu", "--mixed_precision", "bf16", "main.py",
"accelerate", "launch", "--config-file", "./accelerate_config/fsdp2.yaml",
"--fsdp_cpu_ram_efficient_loading", "False",
"--fsdp_activation_checkpointing", "False",
"main.py",
"--teacher_name_or_path", tiny_llama_path,
"--student_name_or_path", tiny_llama_path,
"--output_dir", tmp_path,
"--logging_steps", "5",
"--max_steps", "10",
"--max_seq_length", "1024",
"--max_length", "1024",
"--per_device_train_batch_size", "2",
"--per_device_eval_batch_size", "8",
"--gradient_checkpointing", "True",
"--fsdp", "full_shard auto_wrap",
"--fsdp_transformer_layer_cls_to_wrap", "LlamaDecoderLayer",
"--max_steps", "10",
"--logging_steps", "5",
],
"llm_distill",
)