
Commit 10e3e2b

Merge branch 'main' into jennifchen/cp_amax_sync

2 parents 34c11ef + 17439e6

24 files changed: +211 -573 lines changed

.github/workflows/example_tests.yml

Lines changed: 1 addition & 0 deletions

@@ -69,6 +69,7 @@ jobs:
       image: nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2.post2
       env:
         PIP_CONSTRAINT: "" # Disable pip constraint for upgrading packages
+        HF_TOKEN: ${{ secrets.HF_TOKEN }}
     steps: &example_steps
       - uses: actions/checkout@v4
       - uses: nv-gha-runners/setup-proxy-cache@main
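For context, here is a minimal sketch of how a test step might consume the newly added `HF_TOKEN` secret at runtime. The `huggingface_hub` login call is an assumption layered on top of this diff, not part of it:

import os

from huggingface_hub import login

# Hypothetical consumer of the HF_TOKEN env var injected above.
token = os.environ.get("HF_TOKEN")
if token:
    # Authenticate so gated Hugging Face repos can be downloaded in CI.
    login(token=token)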

.vscode/settings.json

Lines changed: 3 additions & 1 deletion

@@ -42,5 +42,7 @@
   "evenBetterToml.schema.enabled": false, // disable toml/json schema since we have custom fields
   "python.analysis.extraPaths": [
     "./tests/" // add tests to python path just like pytest does in pyproject.toml
-  ]
+  ],
+  "git.alwaysSignOff": true,
+  "git.enableCommitSigning": true,
 }

README.md

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 ______________________________________________________________________

-The **NVIDIA TensorRT Model Optimizer** (referred to as **Model Optimizer**, or **ModelOpt**) is a library comprising state-of-the-art model optimization [techniques](#techniques) including quantization, distillation, pruning, speculative decoding and sparsity to accelerate models.
+**NVIDIA TensorRT Model Optimizer** (referred to as **Model Optimizer**, or **ModelOpt**) is a library comprising state-of-the-art model optimization [techniques](#techniques) including quantization, distillation, pruning, speculative decoding and sparsity to accelerate models.

 **[Input]** Model Optimizer currently supports inputs of a [Hugging Face](https://huggingface.co/), [PyTorch](https://github.com/pytorch/pytorch) or [ONNX](https://github.com/onnx/onnx) model.

docs/source/guides/3_pruning.rst

Lines changed: 2 additions & 2 deletions

@@ -190,7 +190,7 @@ Following info will be printed before the pruning process is started:
 ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
 ┃ Constraint   ┃ min          ┃ centroid     ┃ max          ┃ max/min ratio ┃
 ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
-│ flops        │ 274.34M      │ 1.28G        │ 4.59G        │ 16.73         │
+│ flops        │ 548.68M      │ 2.56G        │ 9.18G        │ 16.73         │
 │ params       │ 2.70M        │ 9.75M        │ 25.50M       │ 9.43          │
 └──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘

@@ -199,7 +199,7 @@ Following info will be printed before the pruning process is started:
 ┃              ┃              ┃ Satisfiable  ┃
 ┃ Constraint   ┃ Upper Bound  ┃ Upper Bound  ┃
 ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
-│ flops        │ 2.75G        │ True         │
+│ flops        │ 5.50G        │ True         │
 └──────────────┴──────────────┴──────────────┘
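For orientation, a minimal sketch of the pruning call that prints tables like these, modeled on the `mtp.prune` usage in the cifar_resnet notebook later in this commit; `model`, `train_loader`, and the input shape are placeholders, and the 5.50G bound mirrors the updated table:

import modelopt.torch.prune as mtp
import torch

dummy_input = torch.randn(1, 3, 224, 224)  # assumed input shape

pruned_model, _ = mtp.prune(
    model=model,  # placeholder: your trained network
    mode="fastnas",
    constraints={"flops": 5.5e9},  # the 5.50G upper bound from the table
    dummy_input=dummy_input,
    config={"data_loader": train_loader},  # placeholder calibration loader
)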

docs/source/guides/7_nas.rst

Lines changed: 6 additions & 6 deletions

@@ -109,8 +109,8 @@ the search space together with your deployment constraints using

 import torch

-# Looking for a subnet with at most 2 GFLOPs
-constraints = {"flops": 2.0e9}
+# Looking for a subnet with at most 4 GFLOPs
+constraints = {"flops": 4.0e9}

 # Measure FLOPs against dummy_input
 # Can be provided as a single tensor or tuple of input args to the model.

@@ -129,7 +129,7 @@ Following info will be printed:
 ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
 ┃ Constraint   ┃ min          ┃ centroid     ┃ max          ┃ max/min ratio ┃
 ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
-│ flops        │ 487.92M      │ 1.84G        │ 4.59G        │ 9.40          │
+│ flops        │ 975.84M      │ 3.68G        │ 9.18G        │ 9.40          │
 │ params       │ 4.84M        │ 12.33M       │ 25.50M       │ 5.27          │
 └──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘

@@ -138,7 +138,7 @@ Following info will be printed:
 ┃              ┃              ┃ Satisfiable  ┃
 ┃ Constraint   ┃ Upper Bound  ┃ Upper Bound  ┃
 ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
-│ flops        │ 2.00G        │ True         │
+│ flops        │ 4.00G        │ True         │
 └──────────────┴──────────────┴──────────────┘

 Search Space Summary:

@@ -242,8 +242,8 @@ Below is an example of running search on an AutoNAS converted and trained model.
 # Specify the sample input including target data shape for FLOPs calculation.
 dummy_input = torch.randn(1, 3, 224, 224)

-# Looking for a subnet with at most 2 GFLOPs
-search_constraints = {"flops": 2.0e9}
+# Looking for a subnet with at most 4 GFLOPs
+search_constraints = {"flops": 4.0e9}

 # search_res (dict) contains state_dict / stats of the searcher
 searched_model, search_res = mtn.search(
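A hedged sketch of how these pieces fit together, assuming `mtn.profile` is the profiling entry point this guide introduces and that `model` is already an AutoNAS-converted network:

import modelopt.torch.nas as mtn
import torch

dummy_input = torch.randn(1, 3, 224, 224)
constraints = {"flops": 4.0e9}  # at most 4 GFLOPs, matching the updated hunks

# Prints constraint and search-space tables like the ones shown above.
mtn.profile(model, dummy_input, constraints=constraints)  # model: placeholder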

examples/llm_ptq/example_utils.py

Lines changed: 1 addition & 3 deletions

@@ -204,9 +204,7 @@ def get_model(
         if auto_model_module != AutoModelForCausalLM:
             model_kwargs2.pop("trust_remote_code", None)
         model_kwargs2["torch_dtype"] = torch_dtype
-        # DeciLMForCausalLM does not support max_memory argument
-        if "architectures" in hf_config and "DeciLMForCausalLM" in hf_config.architectures:
-            model_kwargs2.pop("max_memory", None)
+        model_kwargs2.pop("max_memory", None)
         model = from_config(hf_config, **model_kwargs2)

     max_memory = get_max_memory()

examples/llm_ptq/hf_ptq.py

Lines changed: 3 additions & 3 deletions

@@ -328,6 +328,9 @@ def main(args):
         model = model.language_model
         model_type = get_model_type(model)

+    if model_type == "phi4mm":
+        warnings.warn("Please set the default input_mode to InputMode.LANGUAGE before quantizing.")
+
     if args.sparsity_fmt != "dense":
         if args.batch_size == 0:
             # Sparse algorithm takes more GPU memory so we reduce the batch_size by 4.

@@ -478,9 +481,6 @@ def main(args):
         quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
         quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
         quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
-        warnings.warn(
-            "Please set the default input_mode to InputMode.LANGUAGE before quantizing."
-        )

     if not model_is_already_quantized or calibration_only:
         # Only run single sample for preview
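For context, a minimal sketch of the wildcard-disable pattern these lines rely on; the FP8 base config and the `calibrate_loop` name are illustrative, not taken from this file:

import copy

import modelopt.torch.quantization as mtq

# Start from a stock config and disable quantizers for multimodal towers,
# matching the "*audio*"/"*image*"/"*vision*" patterns in the hunk above.
quant_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG)
for pattern in ("*audio*", "*image*", "*vision*"):
    quant_cfg["quant_cfg"][pattern] = {"enable": False}

model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)  # placeholders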

examples/pruning/cifar_resnet.ipynb

Lines changed: 8 additions & 8 deletions

@@ -489,7 +489,7 @@
 "* prune the model;\n",
 "* obtain a valid pytorch model that can be used for fine-tuning.\n",
 "\n",
-"Let's say you have the ResNet20 model as our base model to prune from and we are looking for a model with at most 30M FLOPs. We can provide search constraints for `flops` and/or `params` by an upper bound. The values can either be absolute numbers (e.g. `30e6`) or a string percentage (e.g. `\"75%\"`). In addition, we should also provide our training data loader to [mtp.prune](../reference/generated/modelopt.torch.prune.pruning.rst#modelopt.torch.prune.pruning.prune). The training data loader will be used to calibrate the normalization layers in the model. Finally, we will also specify a custom config for configuring the pruning search space to get a more fine-grained selection of pruned nets.\n",
+"Let's say you have the ResNet20 model as our base model to prune from and we are looking for a model with at most 60M FLOPs. We can provide search constraints for `flops` and/or `params` by an upper bound. The values can either be absolute numbers (e.g. `60e6`) or a string percentage (e.g. `\"75%\"`). In addition, we should also provide our training data loader to [mtp.prune](../reference/generated/modelopt.torch.prune.pruning.rst#modelopt.torch.prune.pruning.prune). The training data loader will be used to calibrate the normalization layers in the model. Finally, we will also specify a custom config for configuring the pruning search space to get a more fine-grained selection of pruned nets.\n",
 "\n",
 "Finally, we can store the pruned architecture and weights using [mto.save](../reference/generated/modelopt.torch.opt.conversion.rst#modelopt.torch.opt.conversion.save).\n",
 "\n",

@@ -529,7 +529,7 @@
 "┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
 "┃ Constraint   ┃ min          ┃ centroid     ┃ max          ┃ max/min ratio ┃\n",
 "┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-"│ flops        │ 24.33M       │ 27.57M       │ 40.55M       │ 1.67          │\n",
+"│ flops        │ 48.66M       │ 55.14M       │ 81.10M       │ 1.67          │\n",
 "│ params       │ 90.94K       │ 141.63K      │ 268.35K      │ 2.95          │\n",
 "└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘\n",
 " \n",

@@ -538,7 +538,7 @@
 "┃              ┃              ┃ Satisfiable  ┃\n",
 "┃ Constraint   ┃ Upper Bound  ┃ Upper Bound  ┃\n",
 "┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n",
-"│ flops        │ 30.00M       │ True         │\n",
+"│ flops        │ 60.00M       │ True         │\n",
 "└──────────────┴──────────────┴──────────────┘\n",
 "\n",
 "\n",

@@ -618,7 +618,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"[best_subnet_constraints] = {'params': '173.88K', 'flops': '29.64M'}\n"
+"[best_subnet_constraints] = {'params': '173.88K', 'flops': '59.28M'}\n"
 ]
 },
 {

@@ -656,7 +656,7 @@
 "pruned_model, _ = mtp.prune(\n",
 "    model=resnet20(ckpt=\"resnet20.pth\"),\n",
 "    mode=[(\"fastnas\", config)],\n",
-"    constraints={\"flops\": 30e6},\n",
+"    constraints={\"flops\": 60e6},\n",
 "    dummy_input=dummy_input,\n",
 "    config={\n",
 "        \"data_loader\": train_loader,\n",

@@ -676,7 +676,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"As we can see, the best subnet (29.6M FLOPs) has fitted our constraint of 30M FLOPs. We can also see a drop in validation accuracy for the searched model. This is very common after pruning and fine-tuning is necessary for this model.\n",
+"As we can see, the best subnet (59.3M FLOPs) has fitted our constraint of 60M FLOPs. We can also see a drop in validation accuracy for the searched model. This is very common after pruning and fine-tuning is necessary for this model.\n",
 "\n",
 "#### Restore the pruned subnet using [mto.restore](../reference/generated/modelopt.torch.opt.conversion.rst#modelopt.torch.opt.conversion.restore)"
 ]

@@ -795,8 +795,8 @@
 "\n",
 "| Model           | FLOPs      | Params     | Test Accuracy     |\n",
 "| --------------- | ---------- | ---------- | ----------------- |\n",
-"| ResNet20        | 40.6M      | 268k       | 90.9%             |\n",
-"| FastNAS subnet  | 29.6M      | 174k       | 90.3%             |\n",
+"| ResNet20        | 81.2M      | 268k       | 90.9%             |\n",
+"| FastNAS subnet  | 59.2M      | 174k       | 90.3%             |\n",
 "\n",
 "As we see here, we have reduced the FLOPs and number of parameters which would also result in a improvement in latency with very little loss in accuracy. Good job!\n",
 "\n",

examples/speculative_decoding/eagle_utils.py

Lines changed: 31 additions & 0 deletions

@@ -19,11 +19,21 @@

 import torch
 import transformers
+from ar_validate import validate_ar
+from datasets import load_dataset
 from torch.utils.data import Dataset
+from transformers import TrainerCallback
 from transformers.trainer_pt_utils import LabelSmoother

 from modelopt.torch.utils import print_rank_0

+try:
+    import wandb
+
+    wandb.init()
+except ImportError:
+    wandb = None
+
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index

 REMOVE_THINK_CHAT_TEMPLATE = (

@@ -382,3 +392,24 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         }

         return batch
+
+
+class ARValidationCallback(TrainerCallback):
+    def __init__(self, ar_validate_steps: int = 1000):
+        self.ar_validate_steps = ar_validate_steps
+
+    def on_step_end(self, args, state, control, **kwargs):
+        if self.ar_validate_steps <= 0:
+            return control
+        if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
+            print_rank_0("Running AR validation...")
+            ars = validate_ar(
+                model=kwargs["model"],
+                tokenizer=kwargs["processing_class"],
+                ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
+                device=kwargs["model"].device,
+            )
+            print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
+            if wandb:
+                wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
+        return control
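For context, the wiring that consumes this relocated callback, mirroring the Trainer setup in the main.py hunk below; `model`, `tokenizer`, and `training_args` are assumed to be built as in main.py's train():

from eagle_utils import ARValidationCallback
from transformers import Trainer

trainer = Trainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    callbacks=[ARValidationCallback(training_args.ar_validate_steps)],
)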

examples/speculative_decoding/main.py

Lines changed: 3 additions & 34 deletions

@@ -36,24 +36,15 @@

 import torch
 import transformers
-from ar_validate import validate_ar
-from datasets import load_dataset
-from eagle_utils import make_eagle_supervised_data_module
+from eagle_utils import ARValidationCallback, make_eagle_supervised_data_module
 from medusa_utils import make_medusa_supervised_data_module
-from transformers import Trainer, TrainerCallback
+from transformers import Trainer
 from transformers.trainer_utils import get_last_checkpoint

 import modelopt.torch.opt as mto
 import modelopt.torch.speculative as mtsp
 from modelopt.torch.utils import print_rank_0

-try:
-    import wandb
-
-    wandb.init()
-except ImportError:
-    wandb = None
-
 torch.manual_seed(0)
 mto.enable_huggingface_checkpointing()

@@ -147,9 +138,8 @@ def train():
         model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto")
         tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
     else:
-        model_kwargs = {"num_hidden_layers": 0} if use_offline_training else {}
         model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.model_name_or_path, torch_dtype="auto", **model_kwargs
+            model_args.model_name_or_path, torch_dtype="auto", device_map="cpu"
         )
         if use_offline_training:
             # When doing offline training, we need to set num_hidden_layers

@@ -231,34 +221,13 @@ def train():
         tokenizer, data_args, use_offline_training, max_length=training_args.training_seq_len
     )

-    class ARValidationCallback(TrainerCallback):
-        def __init__(self, ar_validate_steps: int = 500):
-            self.ar_validate_steps = ar_validate_steps
-
-        def on_step_end(self, args, state, control, **kwargs):
-            if self.ar_validate_steps <= 0:
-                return control
-            if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
-                print_rank_0("Running AR validation...")
-                ars = validate_ar(
-                    model=kwargs["model"],
-                    tokenizer=kwargs["processing_class"],
-                    ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
-                    device=kwargs["model"].device,
-                )
-                print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
-                if wandb:
-                    wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
-            return control
-
     trainer = Trainer(
         model=model,
         processing_class=tokenizer,
         args=training_args,
         callbacks=[ARValidationCallback(training_args.ar_validate_steps)],
         **data_module,
     )
-    trainer._move_model_to_device(model, trainer.args.device)

     # Manually enable this to return loss in eval
     trainer.can_return_loss = True
