
Commit fc9cca8

Merge branch 'main' into yeyu/move_offline_eagle_to_online
2 parents 0c203e0 + b895dc5

16 files changed, +169 -141 lines changed


docs/source/guides/4_distillation.rst

Lines changed: 9 additions & 13 deletions
@@ -16,9 +16,9 @@ a more powerful teacher model using :mod:`modelopt.torch.distill <modelopt.torch
    interaction between the two.
 #. **Distillation training**: Seamlessly use the meta-model in place of the original model and run
    the original script with only one additional line of code for loss calculation.
-#. **Checkpoint and re-load**: Save the model via :meth:`mto.save <modelopt.torch.opt.conversion.save>` and
-   restore via :meth:`mto.restore <modelopt.torch.opt.conversion.restore>`. See :ref:`saving and restoring <save-restore>`
-   to learn more.
+#. **Checkpoint and re-load**: Save the model via :meth:`mto.save <modelopt.torch.opt.conversion.save>`.
+   Note that restoring the model (via :meth:`mto.restore <modelopt.torch.opt.conversion.restore>`)
+   will not reinstantiate the distillation meta-model, in order to avoid unpickling issues.
 
 *To find out more about Distillation and related concepts, please refer to the below section*
 :ref:`Distillation Concepts <distillation-concepts>`.
@@ -44,7 +44,7 @@ Example usage:
 
     # Configure and convert for distillation
     distillation_config = {
-        # `teacher_model` is a model class or callable, or a tuple.
+        # `teacher_model` is a model, model class, callable, or a tuple.
         # If a tuple, it must be of the form (model_cls_or_callable,) or
         # (model_cls_or_callable, args) or (model_cls_or_callable, args, kwargs).
         "teacher_model": teacher_model,
@@ -53,15 +53,9 @@ Example usage:
     }
     distillation_model = mtd.convert(model, mode=[("kd_loss", distillation_config)])
 
-    # Export model in original class form
+    # Export model in original class, with only previously-present attributes
     model_exported = mtd.export(distillation_model)
 
-.. note::
-    The config requires a (non-lambda) Callable to return a teacher model in place of the model
-    itself. This is to avoid re-saving the teacher state dict upon saving the Distillation
-    meta model. Thus, the same callable must be available in the namespace when restoring via
-    the :meth:`mto.restore <modelopt.torch.opt.conversion.restore>` utility.
-
 .. tip::
     When training the student on a small corpus of ground truth data, consider using :class:`MFTLoss <modelopt.torch.distill.MFTLoss>` to perform Minifinetuning in lieu of the standard
     :class:`LogitsDistillationLoss <modelopt.torch.distill.losses.LogitsDistillationLoss>`. This will allow the student to learn from the teacher's distribution while adapting to the new data, improving specialization on the new data without overwriting the teacher's general knowledge.
@@ -170,10 +164,12 @@ outputs in the same order as well:
 The intermediate outputs for the losses are captured by the
 :class:`DistillationModel <modelopt.torch.distill.distillation_model.DistillationModel>` and then the loss(es) are
 invoked using :meth:`DistillationModel.compute_kd_loss() <modelopt.torch.distill.distillation_model.DistillationModel.compute_kd_loss>`.
-If present, the original student's non-distillation loss is passed in as an argument.
+If present, the original student's non-distillation loss can be passed in as an argument.
 
 Writing a custom loss function is often necessary, especially to handle outputs that need to be processed
-to obtain the logits and activations.
+to obtain the logits and activations. Additional arguments to the loss function can be passed in to
+:meth:`DistillationModel.compute_kd_loss() <modelopt.torch.distill.distillation_model.DistillationModel.compute_kd_loss>`
+as ``kwargs``.
 
 Loss Balancer
 ^^^^^^^^^^^^^
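Taken together, the updated API surface above can be exercised end to end with a minimal sketch like the following (toy modules and the checkpoint filename are illustrative; the teacher is passed directly as an `nn.Module` instance):

```python
import torch
import torch.nn as nn

import modelopt.torch.distill as mtd
import modelopt.torch.opt as mto

# Toy student/teacher standing in for real models.
student = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
teacher = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 8))

distillation_config = {
    # The teacher may be passed directly as a module; a class, callable,
    # or (model_cls_or_callable, args, kwargs) tuple is also accepted.
    "teacher_model": teacher,
    "criterion": mtd.LogitsDistillationLoss(),
}
distillation_model = mtd.convert(student, mode=[("kd_loss", distillation_config)])

# Forward the meta-model, then compute the distillation loss; additional keyword
# arguments would be forwarded to the loss function via compute_kd_loss(**kwargs).
logits = distillation_model(torch.randn(4, 16))
loss = distillation_model.compute_kd_loss()
loss.backward()

# Save with modelopt; restoring will not reinstantiate the distillation meta-model.
mto.save(distillation_model, "distill_ckpt.pth")

# Export back to the original class, keeping only previously-present attributes.
model_exported = mtd.export(distillation_model)
```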

examples/llm_distill/README.md

Lines changed: 16 additions & 22 deletions
@@ -39,13 +39,9 @@ First obtain both a pretrained model to act as the teacher and a (usually smalle
 ```python
 from transformers import AutoModelForCausalLM
 
-# Define student
+# Define student & teacher
 student_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
-
-# Define callable which returns teacher
-def teacher_factory():
-    teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")
-    return teacher_model
+teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")
 ```
 
 ### Set up the meta model
@@ -58,15 +54,15 @@ Please see an example Distillation setup below. This example assumes the outputs
 import modelopt.torch.distill as mtd
 
 distillation_config = {
-    "teacher_model": teacher_factory,  # model initializer
+    "teacher_model": teacher_model,
     "criterion": mtd.LogitsDistillationLoss(),  # callable receiving student and teacher outputs, in order
     "loss_balancer": mtd.StaticLossBalancer(),  # combines multiple losses; omit if only one distillation loss used
 }
 
 distillation_model = mtd.convert(student_model, mode=[("kd_loss", distillation_config)])
 ```
 
-The `teacher_model` can be either a callable which returns an `nn.Module` or a tuple of `(model_cls, args, kwargs)`. The `criterion` is the distillation loss used between student and teacher tensors. The `loss_balancer` determines how the original and distillation losses are combined (if needed).
+The `teacher_model` can be either an `nn.Module`, a callable which returns an `nn.Module`, or a tuple of `(model_cls, args, kwargs)`. The `criterion` is the distillation loss used between student and teacher tensors. The `loss_balancer` determines how the original and distillation losses are combined (if needed).
 
 See [Distillation](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/4_distillation.html) for more info.
 
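For completeness, the other accepted `teacher_model` forms could be written as in the sketch below, reusing `distillation_config` from the snippet above (the factory function is illustrative and mirrors the callable this example used previously):

```python
from transformers import AutoModelForCausalLM


def teacher_factory():
    # Callable form: the teacher is only instantiated when the config is applied.
    return AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")


# Equivalent alternatives to passing the nn.Module instance directly:
config_with_callable = {**distillation_config, "teacher_model": teacher_factory}
config_with_tuple = {
    **distillation_config,
    # (model_cls_or_callable, args, kwargs), per the documented tuple form.
    "teacher_model": (
        AutoModelForCausalLM.from_pretrained,
        ("meta-llama/Llama-3.1-70B-Instruct",),
        {},
    ),
}
```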
@@ -158,35 +154,33 @@ Keep in mind the training loss of the distillation run is not directly comparabl
 ### Train teacher
 
 ```bash
-accelerate launch --multi_gpu --mixed_precision bf16 main.py \
+accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
+    main.py \
     --single_model \
     --teacher_name_or_path 'meta-llama/Llama-2-7b-hf' \
     --output_dir ./llama2-7b-sft \
-    --logging_steps 5 \
-    --max_steps 400 \
-    --max_seq_length 2048 \
+    --max_length 2048 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 4 \
-    --gradient_checkpointing True \
-    --fsdp 'full_shard auto_wrap' \
-    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer
+    --max_steps 400 \
+    --logging_steps 5
 ```
 
 ### Distill teacher into student
 
 ```bash
-accelerate launch --multi_gpu --mixed_precision bf16 main.py \
+accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
+    --fsdp_cpu_ram_efficient_loading False \
+    --fsdp_activation_checkpointing False \
+    main.py \
     --teacher_name_or_path ./llama2-7b-sft \
     --student_name_or_path 'TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T' \
     --output_dir ./llama2-distill \
-    --logging_steps 5 \
-    --max_steps 200 \
-    --max_seq_length 2048 \
+    --max_length 2048 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 4 \
-    --gradient_checkpointing False \
-    --fsdp 'full_shard auto_wrap' \
-    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer
+    --max_steps 200 \
+    --logging_steps 5
 ```
 
 > [!NOTE]

examples/llm_distill/accelerate_config/fsdp2.yaml

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: FSDP
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+fsdp_config:
+  fsdp_activation_checkpointing: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_offload_params: false
+  fsdp_reshard_after_forward: true
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+  fsdp_version: 2
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: gpu
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false

examples/llm_distill/main.py

Lines changed: 14 additions & 28 deletions
@@ -21,7 +21,6 @@
 import torch
 import torch.distributed
 import transformers
-from accelerate import PartialState
 from accelerate.logging import get_logger
 from transformers import AutoTokenizer
 from trl import SFTTrainer
@@ -48,38 +47,28 @@ class TrainingArguments(transformers.TrainingArguments):
     do_train: bool = True
     do_eval: bool = True
     save_strategy: str = "no"
-    max_seq_length: int = 1024
+    max_length: int = 1024
     optim: str = "adamw_torch"
     learning_rate: float = 1e-5
     lr_scheduler_type: str = "cosine"
     dataloader_drop_last: bool = True
     dataset_num_proc: int = 8
-    dataset_batch_size: int = 500
     bf16: bool = True
     tf32: bool = True
 
 
 def llama_text_format_func(sample):
-    texts = []
-    for p, q, r in zip(sample["system_prompt"], sample["question"], sample["response"]):
-        if not p:
-            texts.append(f"<s>[INST] {q}[/INST]\n{r}</s>")
-        else:
-            texts.append(f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>")
-    return texts
+    p, q, r = sample["system_prompt"], sample["question"], sample["response"]
+    if not p:
+        return f"<s>[INST] {q}[/INST]\n{r}</s>"
+    else:
+        return f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>"
 
 
 class KDSFTTrainer(SFTTrainer, KDTrainer):
     pass
 
 
-def _teacher_factory(model_name_or_path):
-    return transformers.AutoModelForCausalLM.from_pretrained(
-        model_name_or_path,
-        device_map=PartialState().process_index,
-    )
-
-
 def train():
     parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
     model_args, training_args = parser.parse_args_into_dataclasses()
@@ -117,34 +106,31 @@ def train():
 
     if model_args.single_model:
         logger.info("Loading single model only...")
-        model = _teacher_factory(model_path)
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_path, dtype=torch.bfloat16 if training_args.bf16 else None
+        )
         logger.info("Model loaded.")
     else:
         logger.info("Loading student model...")
         model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.student_name_or_path,
-            device_map=PartialState().process_index,
+            model_args.student_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
         )
         logger.info("Student loaded.")
         # Load checkpoint
         logger.info("Loading teacher model and converting to Distillation model...")
+        teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
+        )
         kd_config = {
-            "teacher_model": (
-                _teacher_factory,
-                (model_args.teacher_name_or_path,),
-                {},
-            ),
+            "teacher_model": teacher_model,
             "criterion": LMLogitsLoss(),
-            "expose_minimal_state_dict": False,  # FSDP forces us to disable this
         }
         model = mtd.convert(model, mode=[("kd_loss", kd_config)])
         logger.info("Models converted.")
 
     # Fix problematic settings that cause excessive warnings
     model.generation_config.temperature = None
     model.generation_config.top_p = None
-    if training_args.gradient_checkpointing:
-        training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
 
     # Trainer
     trainer_cls = SFTTrainer if model_args.single_model else KDSFTTrainer
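For intuition, a distillation-aware trainer combines the two losses roughly as sketched below. This is only an illustration of the `compute_kd_loss()` API, not the actual `KDTrainer` implementation, and the `student_loss` keyword name is an assumption:

```python
from trl import SFTTrainer


class IllustrativeKDSFTTrainer(SFTTrainer):
    """Sketch only: shows where compute_kd_loss() slots into a Trainer loop."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Regular student (SFT) loss from the parent trainer.
        student_loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)
        # Blend in the distillation loss captured during the forward pass.
        loss = model.compute_kd_loss(student_loss=student_loss)
        return (loss, outputs) if return_outputs else loss
```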

examples/llm_distill/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 pyarrow
-trl==0.13.0
+trl>=0.23.0

examples/onnx_ptq/evaluate.py

Lines changed: 10 additions & 1 deletion
@@ -35,6 +35,12 @@ def main():
         help="""Path to the image classification ONNX model with input shape of
         [batch_size,3,224,224] and output shape of [1,1000]""",
     )
+    parser.add_argument(
+        "--engine_path",
+        type=str,
+        required=True,
+        help="Path to the TensorRT engine",
+    )
     parser.add_argument(
         "--imagenet_path", type=str, default=None, help="Path to the imagenet dataset"
     )
@@ -73,7 +79,10 @@ def main():
     client = RuntimeRegistry.get(deployment)
 
     # Compile the ONNX model to TRT engine and create the device model
-    compiled_model = client.ir_to_compiled(onnx_bytes)
+    compilation_args = {
+        "engine_path": args.engine_path,
+    }
+    compiled_model = client.ir_to_compiled(onnx_bytes, compilation_args)
     device_model = DeviceModel(client, compiled_model, metadata={})
 
     top1_accuracy, top5_accuracy = 0.0, 0.0
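Since the runtime client reads these as optional keys, a caller can spell out the full set of supported compilation arguments along the following lines (values are illustrative; `client` and `onnx_bytes` are set up as in the script above):

```python
compilation_args = {
    "engine_path": "./artifacts/model.engine",  # hypothetical location for the built engine
    "dynamic_shapes": None,  # optional runtime-specific dynamic-shape settings
    "plugin_config": None,   # optional TensorRT plugin settings
}
compiled_model = client.ir_to_compiled(onnx_bytes, compilation_args)
```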

modelopt/torch/_deploy/_runtime/tensorrt/engine_builder.py

Lines changed: 10 additions & 2 deletions
@@ -122,6 +122,7 @@ def _update_dynamic_shapes(dynamic_shapes: dict, cmd: list[str]) -> None:
 def build_engine(
     onnx_bytes: OnnxBytes,
     trt_mode: str = TRTMode.FLOAT32,
+    engine_path: Path | None = None,
     calib_cache: str | None = None,
     dynamic_shapes: dict | None = None,
     plugin_config: dict | None = None,
@@ -133,6 +134,7 @@
 
     Args:
         onnx_bytes: Data of the ONNX model stored as an OnnxBytes object.
+        engine_path: Path to save the TensorRT engine.
         trt_mode: The precision with which the TensorRT engine will be built. Supported modes are:
             - TRTMode.FLOAT32
             - TRTMode.FLOAT16
@@ -202,22 +204,28 @@ def _build_command(
 
     def _setup_files_and_paths(
         tmp_dir_path: Path,
+        engine_path: Path | None,
     ) -> tuple[Path, Path, Path | None, Path | None, Path]:
         tmp_onnx_dir = tmp_dir_path / "onnx"
         onnx_bytes.write_to_disk(str(tmp_onnx_dir))
         onnx_path = tmp_onnx_dir / f"{onnx_bytes.model_name}.onnx"
 
         final_output_dir = Path(output_dir or Path(gettempdir()) / DEFAULT_ARTIFACT_DIR)
         final_output_dir.mkdir(parents=True, exist_ok=True)
-        engine_path = final_output_dir / f"{onnx_bytes.model_name}.engine"
+        engine_path = (
+            Path(engine_path)
+            if engine_path
+            else final_output_dir / f"{onnx_bytes.model_name}.engine"
+        )
+        engine_path.parent.mkdir(parents=True, exist_ok=True)
         calib_cache_path = final_output_dir / "calib_cache" if calib_cache else None
         timing_cache_path = final_output_dir / "timing.cache"
 
         return onnx_path, engine_path, calib_cache_path, timing_cache_path, final_output_dir
 
     with TemporaryDirectory() as tmp_dir:
         onnx_path, engine_path, calib_cache_path, timing_cache_path, final_output_dir = (
-            _setup_files_and_paths(Path(tmp_dir))
+            _setup_files_and_paths(Path(tmp_dir), engine_path)
         )
         cmd = _build_command(onnx_path, engine_path, calib_cache_path, timing_cache_path)

modelopt/torch/_deploy/_runtime/trt_client.py

Lines changed: 3 additions & 1 deletion
@@ -73,7 +73,8 @@ def _ir_to_compiled(
 
         Args:
            ir_bytes: The ONNX model bytes.
-            compilation_args: A dictionary of compilation arguments. Supported args: dynamic_shapes, plugin_config.
+            compilation_args: A dictionary of compilation arguments.
+                The following arguments are supported: dynamic_shapes, plugin_config, engine_path.
 
         Returns:
             The compiled TRT engine bytes.
@@ -85,6 +86,7 @@ def _ir_to_compiled(
             onnx_bytes,
             dynamic_shapes=compilation_args.get("dynamic_shapes"),  # type: ignore[union-attr]
             plugin_config=compilation_args.get("plugin_config"),  # type: ignore[union-attr]
+            engine_path=compilation_args.get("engine_path"),  # type: ignore[union-attr]
             trt_mode=self.deployment["precision"],
             verbose=(self.deployment.get("verbose", "false").lower() == "true"),
         )
modelopt/torch/distill/config.py

Lines changed: 4 additions & 7 deletions
@@ -16,20 +16,18 @@
 """Configurations for distillation modes."""
 
 import warnings
-from collections.abc import Callable
 from typing import Any, Union
 
 import pydantic
-import torch.nn as nn
 from torch.nn.modules.loss import _Loss as Loss
 
 from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
+from modelopt.torch.utils.network import ModelLike
 
 from .loss_balancers import DistillationLossBalancer
 
 __all__ = ["KDLossConfig"]
 
-TeacherModel = type[nn.Module] | tuple | Callable
 Criterion = Union[Loss, dict[tuple[str, str], Loss]]  # noqa: UP007
 
 
@@ -42,14 +40,13 @@ class KDLossConfig(ModeloptBaseConfig):
     # TODO: we should really think about a better way to configure KDLossConfig
     model_config = pydantic.ConfigDict(extra="forbid", arbitrary_types_allowed=True)
 
-    teacher_model: TeacherModel | None = ModeloptField(
+    teacher_model: ModelLike | None = ModeloptField(
         default=None,
         title="Teacher model",
         description=(
-            "The class or callable or tuple to initialize the teacher model using"
+            "The module, class, callable, or tuple to initialize the teacher model using"
             " :meth:`init_model_from_model_like"
-            " <modelopt.torch.utils.network.init_model_from_model_like>`. This cannot already be an"
-            " instance of nn.Module."
+            " <modelopt.torch.utils.network.init_model_from_model_like>`."
         ),
     )
     criterion: Criterion | None = ModeloptField(
