
Commit 6dffcd0

Add minitron pruning and distillation guidelines in pruning readme (#419)
Signed-off-by: Keval Morabia <[email protected]>
1 parent ee19a7e commit 6dffcd0

2 files changed: +91 −5 lines changed


CHANGELOG.rst

Lines changed: 6 additions & 2 deletions
@@ -1,7 +1,7 @@
 Model Optimizer Changelog (Linux)
 =================================

-0.39 (2025-10-xx)
+0.39 (2025-11-xx)
 ^^^^^^^^^^^^^^^^^

 **Deprecations**
@@ -12,7 +12,11 @@ Model Optimizer Changelog (Linux)
 - Add LoRA mode support for MCore in a new peft submodule: ``modelopt.torch.peft.update_model(model, LORA_CFG)``.
 - Support PTQ and fakequant in vLLM for fast evaluation of arbitrary quantization formats. See ``examples/vllm_serve`` for more details.

-0.37 (2025-09-xx)
+**Documentation**
+
+- Add general guidelines for Minitron pruning and distillation. See `examples/pruning/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/pruning#pruning-guidelines>`_ for more details.
+
+0.37 (2025-10-08)
 ^^^^^^^^^^^^^^^^^

 **Deprecations**

examples/pruning/README.md

Lines changed: 85 additions & 3 deletions
@@ -17,6 +17,8 @@ This section focuses on applying Model Optimizer's state-of-the-art complementar
 | Pre-Requisites | Required & optional packages to use this technique | \[[Link](#pre-requisites)\] | |
 | Getting Started | Learn how to use the pruning API | \[[Link](#getting-started)\] | \[[docs](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/3_pruning.html)\] |
 | Support Matrix | View the support matrix to see available pruning algorithms and their compatibility with different models and frameworks | \[[Link](#support-matrix)\] | |
+| Pruning Guidelines | Guidelines for choosing how and how much to prune for best results | \[[Link](#pruning-guidelines)\] | |
+| Examples | Examples of different pruning methods | \[[Link](#examples)\] | |
 | Resources | Extra links to relevant resources | \[[Link](#resources)\] | |

 </div>
@@ -93,6 +95,84 @@ If your model parameters are already sorted, you can skip the sorting step by se

 > *<sup>1.</sup>Only Pipeline Parallel models are supported. Hugging Face models can be converted to NeMo format and used subsequently.*

+## Pruning Guidelines
+
+### Minitron
+
+This section provides recommendations for choosing pruning strategies and distillation hyperparameters for Minitron pruning to help achieve the best latency-accuracy trade-offs.
+
+#### Depth Pruning
+
+Depth pruning reduces the number of layers (`num_layers`) in the model.
+
+**Advantages:**
+
+- Simpler to configure - only 1 parameter to tune
+- Faster inference than width-pruned models at a fixed number of parameters
+
+**Recommendations:**
+
+- Up to **1/3rd parameter reduction** can generally result in a model above the Pareto frontier with good latency-accuracy trade-off (when using a good quality dataset for distillation with ~80-100B tokens)
+- For pruning **>50%**, use iterative pruning: compress by 30%, perform distillation, then compress again
+
+**Examples:**
+
+- [Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B) (`num_layers=36`) → 6B (`num_layers=24`)
+- [Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) (`num_layers=32`) → 4.5B (`num_layers=16`)
+
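To make the depth-pruning recipe above concrete, here is a minimal sketch of how such a target might be expressed with the Model Optimizer pruning API. The `mcore_minitron` mode name, the `export_config` constraint format, and the `forward_loop` calibration hook are assumptions based on the pruning documentation; check the Megatron-LM pruning example for the exact invocation.

```python
# Minimal sketch (assumed API surface): depth-prune a Megatron-Core model from
# 36 to 24 layers, matching the Qwen3-8B -> 6B example above.
import modelopt.torch.prune as mtp

model = load_megatron_core_model()      # hypothetical helper that builds/loads the model
forward_loop = make_calibration_loop()  # hypothetical loop over calibration batches

model, _ = mtp.prune(
    model,
    mode="mcore_minitron",                              # assumed Minitron mode name
    constraints={"export_config": {"num_layers": 24}},  # target depth after pruning
    dummy_input=None,                                   # assumed unused for Megatron-Core models
    config={"forward_loop": forward_loop},              # activation-based importance estimation
)
```
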
+#### Width Pruning
+
+Width pruning reduces model dimensions per layer such as `hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, and `mamba_head_dim`.
+
+**Advantages:**
+
+- Better accuracy than depth-pruned models at a fixed number of parameters
+
+**Recommendations:**
+
+- Start with pruning `hidden_size` and `ffn_hidden_size` as the simplest configuration
+- Up to **1/3rd parameter reduction** can generally result in a model above the Pareto frontier with good latency-accuracy trade-off (when using a good quality dataset for distillation with ~80-100B tokens)
+- **Axis sensitivity:** MLP dimensions (`ffn_hidden_size`) can typically be pruned more aggressively than embedding dimensions (`hidden_size`) and attention/Mamba dimensions (`num_attention_heads`, `num_query_groups`, `mamba_num_heads`, `mamba_head_dim`)
+- For pruning **>50%**, use iterative pruning: compress by 30%, perform distillation, then compress again
+
+**Examples:**
+
+- [Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B) (`ffn_hidden_size=12288`, `hidden_size=4096`) → 6B (`ffn_hidden_size=9216`, `hidden_size=3584`)
+- [Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) (`ffn_hidden_size=14336`, `hidden_size=4096`) → 4.5B (`ffn_hidden_size=9216`, `hidden_size=3072`)
+- [Nemotron-H-8B-Base-8K](https://huggingface.co/nvidia/Nemotron-H-8B-Base-8K) (`ffn_hidden_size=21504`, `hidden_size=4096`, `mamba_num_heads=128`) → [Nemotron-H-4B-Base-8K](https://huggingface.co/nvidia/Nemotron-H-4B-Base-8K) (`ffn_hidden_size=12288`, `hidden_size=3072`, `mamba_num_heads=112`) - See [paper](https://arxiv.org/pdf/2504.11409)
+
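For width pruning, the target architecture would go into the same assumed `export_config`. A sketch using the Llama-3.1-8B → 4.5B numbers from the examples above, with `ffn_hidden_size` cut more aggressively than `hidden_size` per the axis-sensitivity note (again, mode name and argument structure are assumptions):

```python
# Minimal sketch (same assumed API as the depth-pruning example): width-only target.
width_export_config = {
    "ffn_hidden_size": 9216,  # 14336 -> 9216 (MLP dimension, pruned more aggressively)
    "hidden_size": 3072,      # 4096 -> 3072 (embedding dimension)
}

model, _ = mtp.prune(
    model,
    mode="mcore_minitron",                               # assumed mode name
    constraints={"export_config": width_export_config},
    dummy_input=None,
    config={"forward_loop": forward_loop},
)
```
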
+#### Depth and Width Pruning
+
+For optimal results, combine depth and width pruning. This will require more tuning to find the best architecture.
+
+**Examples:**
+
+- [NVIDIA-Nemotron-Nano-12B-v2](https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-12B-v2) (`ffn_hidden_size=20480`, `hidden_size=5120`, `num_layers=62`) → [NVIDIA-Nemotron-Nano-9B-v2](https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-9B-v2) (`ffn_hidden_size=15680`, `hidden_size=4480`, `num_layers=56`) - See [paper](https://arxiv.org/pdf/2508.14444)
+
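Under the same assumption about the pruning API, combining the two is just a larger `export_config`; for instance, the Nemotron Nano 12B → 9B target from the example above would look roughly like:

```python
# Minimal sketch: combined depth + width target (Nemotron Nano 12B -> 9B example above).
combined_export_config = {
    "num_layers": 56,          # depth: 62 -> 56
    "hidden_size": 4480,       # width: 5120 -> 4480
    "ffn_hidden_size": 15680,  # width: 20480 -> 15680
}
```
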
+#### General Pruning Guidelines
+
+- **Pruning ratio:** Anything **>50% pruning is hard to recover**. For such aggressive pruning, iterative pruning (compress → distill → compress again) is recommended.
+- **Latency-accuracy trade-off:** The more pruning you do, the faster your model will be at the cost of lower accuracy. Choose based on your requirements.
+- **Dataset quality:** Use a high-quality dataset for distillation. If you don't have a specific dataset, [Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1) is recommended.
+- **Post-training:** Further post-training (e.g., instruction tuning, preference alignment) is needed after pruning and distillation on pre-training datasets to improve reasoning capabilities. A good dataset for post-training is [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2).
+
+#### Distillation Hyperparameters
+
+After pruning, distillation is required to recover model accuracy. Below are recommended starting hyperparameters for distillation:
+
+| **Hyperparameter** | **Recommendation** |
+| :---: | :---: |
+| **Sequence Length** | 8192 (or 4096 if dataset has smaller sequences) |
+| **Global Batch Size (GBS)** | 768 |
+| **Micro Batch Size (MBS)** | As large as your GPU memory can accommodate |
+| **Learning Rate (LR)** | 1e-4 → 1e-5 (linear decay) for 30-50% pruning<br>• More compression → higher LR<br>• Less compression → lower LR<br>• As model gets larger → reduce LR to avoid divergence |
+| **Warmup Steps** | 100 |
+| **Training Max Steps** | Num training tokens / (Seq len × GBS)<br>• Recommended: 80-100B tokens |
+| **Data Composition** | • Standard models: 100% pre-training data<br>• Reasoning models: 70% reasoning data + 30% pre-training data |
+
+> [!TIP]
+> If you know the maximum learning rate used during the original training, a good rule of thumb for knowledge distillation is to use **1/5th of that maximum LR** when compressing by ~50%.
+
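The max-steps formula and the LR rule of thumb above work out as follows; the 90B-token budget and the original peak LR are placeholder assumptions within the recommended ranges:

```python
# Worked example for the distillation schedule recommended above.
seq_len = 8192                # recommended sequence length
gbs = 768                     # recommended global batch size
train_tokens = 90e9           # assumed budget within the 80-100B token recommendation

tokens_per_step = seq_len * gbs                    # 6,291,456 tokens per optimizer step
max_steps = round(train_tokens / tokens_per_step)  # ~14,305 training steps
warmup_steps = 100

peak_lr, min_lr = 1e-4, 1e-5  # linear decay, recommended for 30-50% pruning

# Tip above: at ~50% compression, start from ~1/5th of the teacher's original peak LR.
original_peak_lr = 3e-4                  # assumption: teacher's original maximum LR
distill_peak_lr = original_peak_lr / 5   # 6e-5
```
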
 ## Examples

 ### Minitron Pruning for Megatron-LM / NeMo Framework LLMs (e.g. Qwen 3, Nemotron Nano)
@@ -108,10 +188,12 @@ Some of the models pruned using Minitron method followed by distillation and pos

 ### FastNAS Pruning for PyTorch Computer Vision Models

-Checkout the FastNAS pruning interactive notebook [cifar_resnet](./cifar_resnet.ipynb) in this directory
+Check out the FastNAS pruning example usage in the [documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/3_pruning.html#pruning-and-subnet-search).
+
+You can also take a look at the FastNAS pruning interactive notebook [cifar_resnet](./cifar_resnet.ipynb) in this directory
 which showcases the usage of FastNAS for pruning a ResNet 20 model for the CIFAR-10 dataset. The notebook
-also how to profiling the model to understand the search space of possible pruning options and demonstrates
-the usage saving and restoring pruned models.
+also shows how to profile the model to understand the search space of possible pruning options and demonstrates
+how to save and restore pruned models.
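A rough sketch of the FastNAS flow the notebook walks through (prune a small CNN under a FLOPs budget, then save and restore the result); the constraint value, score function, and data loader are placeholders, so follow the linked documentation and notebook for the exact recipe:

```python
# Rough FastNAS sketch (placeholders marked; see the linked docs/notebook for the real recipe).
import torch
import modelopt.torch.opt as mto
import modelopt.torch.prune as mtp

model = build_resnet20()                  # hypothetical: ResNet-20 for CIFAR-10
dummy_input = torch.randn(1, 3, 32, 32)   # CIFAR-10 sized input used for profiling

# Search for a subnet that meets a FLOPs budget, scoring candidates with a validation metric.
pruned_model, _ = mtp.prune(
    model,
    mode="fastnas",
    constraints={"flops": "75%"},          # placeholder budget
    dummy_input=dummy_input,
    config={
        "data_loader": train_loader,       # hypothetical calibration/search data loader
        "score_func": validate_accuracy,   # hypothetical function returning validation accuracy
    },
)

# Save the pruned architecture + weights, and restore them into a fresh model later.
mto.save(pruned_model, "fastnas_pruned.pth")
restored_model = mto.restore(build_resnet20(), "fastnas_pruned.pth")
```
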

 ### GradNAS Pruning for HuggingFace Language Models (e.g. BERT)
