Add minitron pruning and distillation guidelines in pruning readme #419
@@ -17,6 +17,8 @@ This section focuses on applying Model Optimizer's state-of-the-art complementar
| Pre-Requisites | Required & optional packages to use this technique | \[[Link](#pre-requisites)\] | |
| Getting Started | Learn how to use the pruning API | \[[Link](#getting-started)\] | \[[docs](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/3_pruning.html)\] |
| Support Matrix | View the support matrix to see available pruning algorithms and their compatibility with different models and frameworks | \[[Link](#support-matrix)\] | |
| Pruning Guidelines | Guidelines for choosing how and how much to prune for best results | \[[Link](#pruning-guidelines)\] | |
| Examples | Examples of different pruning methods | \[[Link](#examples)\] | |
| Resources | Extra links to relevant resources | \[[Link](#resources)\] | |

</div>
@@ -93,6 +95,84 @@ If your model parameters are already sorted, you can skip the sorting step by se
> *<sup>1.</sup>Only Pipeline Parallel models are supported. Hugging Face models can be converted to NeMo format and used subsequently.*

## Pruning Guidelines

### Minitron

This section provides recommendations for choosing pruning strategies and distillation hyperparameters for Minitron pruning to help achieve the best latency-accuracy trade-offs.

#### Depth Pruning

Depth pruning reduces the number of layers (`num_layers`) in the model.

**Advantages:**

- Simpler to configure - only 1 parameter to tune
- Faster inference than width-pruned models at a fixed number of parameters

**Recommendations:**

- Up to **1/3rd parameter reduction** can generally result in a model above the Pareto frontier with a good latency-accuracy trade-off (when using a good-quality distillation dataset of ~80-100B tokens)
- For pruning **>50%**, use iterative pruning: compress by 30%, perform distillation, then compress again

**Examples:**

- [Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B) (`num_layers=36`) → 6B (`num_layers=24`)
- [Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) (`num_layers=32`) → 4.5B (`num_layers=16`)
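To sanity-check whether a depth-pruning target lands near the recommended reduction, a quick estimate like the sketch below can help. It assumes non-embedding parameters are spread evenly across layers; the parameter figures are illustrative, not official model-card numbers.

```python
# Back-of-envelope estimate of model size after depth pruning.
# Assumption: non-embedding parameters are split evenly across layers,
# which holds approximately for uniform decoder-only stacks.

def depth_pruned_params(total_params: float, embed_params: float,
                        num_layers: int, kept_layers: int) -> float:
    """Estimate parameter count after keeping `kept_layers` of `num_layers`."""
    per_layer = (total_params - embed_params) / num_layers
    return embed_params + per_layer * kept_layers

# Qwen3-8B-style figures (illustrative): ~8.2B total, ~0.6B embedding/head,
# 36 layers pruned to 24 as in the example above.
pruned = depth_pruned_params(8.2e9, 0.6e9, num_layers=36, kept_layers=24)
print(f"~{pruned / 1e9:.1f}B parameters")  # ~5.7B, near the 6B target
```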

#### Width Pruning

Width pruning reduces per-layer model dimensions such as `hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, and `mamba_head_dim`.

**Advantages:**

- Better accuracy than depth-pruned models at a fixed number of parameters

**Recommendations:**

- Start with pruning `hidden_size` and `ffn_hidden_size` as the simplest configuration
- Up to **1/3rd parameter reduction** can generally result in a model above the Pareto frontier with a good latency-accuracy trade-off (when using a good-quality distillation dataset of ~80-100B tokens)
- **Axis sensitivity:** MLP dimensions (`ffn_hidden_size`) can typically be pruned more aggressively than embedding dimensions (`hidden_size`) and attention/Mamba dimensions (`num_attention_heads`, `num_query_groups`, `mamba_num_heads`, `mamba_head_dim`)
- For pruning **>50%**, use iterative pruning: compress by 30%, perform distillation, then compress again

**Examples:**

- [Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B) (`ffn_hidden_size=12288`, `hidden_size=4096`) → 6B (`ffn_hidden_size=9216`, `hidden_size=3584`)
- [Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) (`ffn_hidden_size=14336`, `hidden_size=4096`) → 4.5B (`ffn_hidden_size=9216`, `hidden_size=3072`)
- [Nemotron-H-8B-Base-8K](https://huggingface.co/nvidia/Nemotron-H-8B-Base-8K) (`ffn_hidden_size=21504`, `hidden_size=4096`, `mamba_num_heads=128`) → [Nemotron-H-4B-Base-8K](https://huggingface.co/nvidia/Nemotron-H-4B-Base-8K) (`ffn_hidden_size=12288`, `hidden_size=3072`, `mamba_num_heads=112`) - see the [paper](https://arxiv.org/pdf/2504.11409)
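For the shape of a width-pruning call, below is a minimal sketch following the Minitron pruning flow in this repo's examples. The mode name, the `export_config` constraint keys, and the `forward_loop` calibration hook are assumptions based on those examples; check the Megatron-LM / NeMo pruning example for the exact, version-correct invocation.

```python
# Sketch only: width-prune a Megatron-Core model with Minitron (assumed API;
# verify mode name and config keys against your Model Optimizer version).
import modelopt.torch.prune as mtp

# Target widths from the Llama-3.1-8B -> 4.5B example above.
export_config = {
    "ffn_hidden_size": 9216,
    "hidden_size": 3072,
}

def forward_loop(model):
    # Calibration forward passes used to rank channels/heads by importance.
    # `calib_batches` is a hypothetical iterable of pre-tokenized batches.
    for batch in calib_batches:
        model(**batch)

model, _ = mtp.prune(
    model,                                 # your loaded Megatron-Core model
    mode="mcore_minitron",                 # assumed Minitron mode name
    constraints={"export_config": export_config},
    dummy_input=None,                      # unused by this mode (assumption)
    config={"forward_loop": forward_loop},
)
```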

#### Depth and Width Pruning

For optimal results, combine depth and width pruning. This requires more tuning to find the best architecture; a combined export config is sketched after the example below.

**Examples:**

- [NVIDIA-Nemotron-Nano-12B-v2](https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-12B-v2) (`ffn_hidden_size=20480`, `hidden_size=5120`, `num_layers=62`) → [NVIDIA-Nemotron-Nano-9B-v2](https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-9B-v2) (`ffn_hidden_size=15680`, `hidden_size=4480`, `num_layers=56`) - see the [paper](https://arxiv.org/pdf/2508.14444)
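A combined run simply sets both depth and width axes in one export config, as in this sketch mirroring the Nemotron-Nano numbers above (keys follow the assumed Minitron `export_config` convention from the width-pruning sketch):

```python
# Combined depth + width target, mirroring the 12B -> 9B example above.
# Keys follow the assumed Minitron export-config convention shown earlier.
export_config = {
    "num_layers": 56,          # depth
    "hidden_size": 4480,       # width: embedding dimension
    "ffn_hidden_size": 15680,  # width: MLP dimension
}
```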

#### General Pruning Guidelines

- **Pruning ratio:** Anything **>50% pruning is hard to recover**. For such aggressive pruning, iterative pruning (compress → distill → compress again) is recommended; a staged schedule is sketched after this list.
- **Latency-accuracy trade-off:** The more you prune, the faster your model will be, at the cost of lower accuracy. Choose based on your requirements.
- **Dataset quality:** Use a high-quality dataset for distillation. If you don't have a specific dataset, [Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1) is recommended.
- **Post-training:** After pruning and distillation on pre-training datasets, further post-training (e.g., instruction tuning, preference alignment) is needed to improve reasoning capabilities. A good dataset for post-training is [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2).
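The staged schedule below is a toy illustration of that iterative recommendation: it caps each stage at ~30% pruning and assumes a distillation run between stages. All numbers are illustrative.

```python
# Illustrative staging for >50% total pruning: prune at most ~30% per stage,
# distilling between stages. `target` is the fraction of parameters to keep.

def pruning_stages(target: float, max_step: float = 0.30):
    """Yield (stage, remaining_fraction) until `target` is reached."""
    remaining = 1.0
    stage = 0
    while remaining > target:
        step = min(max_step, 1.0 - target / remaining)
        remaining *= 1.0 - step
        stage += 1
        yield stage, remaining  # distill after each stage before re-pruning

for stage, frac in pruning_stages(target=0.45):  # 55% total pruning
    print(f"stage {stage}: keep {frac:.0%} of original parameters")
# stage 1: keep 70%
# stage 2: keep 49%
# stage 3: keep 45%
```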

#### Distillation Hyperparameters

After pruning, distillation is required to recover model accuracy. Below are recommended starting hyperparameters for distillation:

| **Hyperparameter** | **Recommendation** |
| :---: | :---: |
| **Sequence Length** | 8192 (or 4096 if the dataset has smaller sequences) |
| **Global Batch Size (GBS)** | 768 |
| **Micro Batch Size (MBS)** | As large as your GPU memory can accommodate |
| **Learning Rate (LR)** | 1e-4 → 1e-5 (linear decay) for 30-50% pruning<br>• More compression → higher LR<br>• Less compression → lower LR<br>• As the model gets larger → reduce LR to avoid divergence |
| **Warmup Steps** | 100 |
| **Training Max Steps** | Num training tokens / (Seq len × GBS)<br>• Recommended: 80-100B tokens |
| **Data Composition** | • Standard models: 100% pre-training data<br>• Reasoning models: 70% reasoning data + 30% pre-training data |

> [!TIP]
> If you know the maximum learning rate used during the original training, a good rule of thumb for knowledge distillation is to use **1/5th of that maximum LR** when compressing by ~50%.
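As a worked instance of the max-steps formula and the LR rule of thumb above, here is a small calculation at the recommended settings; the original peak LR value is an assumption for illustration.

```python
# Worked example of "Training Max Steps = num tokens / (seq len x GBS)".
seq_len = 8192
gbs = 768                        # recommended global batch size
train_tokens = 90e9              # middle of the recommended 80-100B range

tokens_per_step = seq_len * gbs                   # 6,291,456 tokens/step
max_steps = round(train_tokens / tokens_per_step)
print(f"max steps: {max_steps:,}")                # max steps: 14,305

# Tip above: at ~50% compression, start at ~1/5 of the original peak LR.
orig_peak_lr = 3e-4                               # assumed original peak LR
print(f"distillation LR: {orig_peak_lr / 5:.0e} -> decay to ~1e-5")
# distillation LR: 6e-05 -> decay to ~1e-5
```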

## Examples

### Minitron Pruning for Megatron-LM / NeMo Framework LLMs (e.g. Qwen 3, Nemotron Nano)
@@ -108,10 +188,12 @@ Some of the models pruned using Minitron method followed by distillation and pos

### FastNAS Pruning for PyTorch Computer Vision Models

Check out the FastNAS pruning example usage in the [documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/3_pruning.html#pruning-and-subnet-search).

You can also take a look at the FastNAS pruning interactive notebook [cifar_resnet](./cifar_resnet.ipynb) in this directory,
which shows how to use FastNAS to prune a ResNet-20 model on the CIFAR-10 dataset. The notebook
also shows how to profile the model to understand the search space of possible pruning options and
how to save and restore pruned models.
Comment on lines +191 to +196
Polish FastNAS paragraph (articles, hyphenation, flow). Minor grammar/style fixes. (Suggested change shown in the paragraph above.)
### GradNAS Pruning for HuggingFace Language Models (e.g. BERT)
🧩 Analysis chain

Sanity-check external links and local notebook path.

Broken external link: the URL for Nemotron-H-4B-Base-8K (https://huggingface.co/nvidia/Nemotron-H-4B-Base-8K) returns HTTP 401; update it to a publicly accessible link or ensure the model page is published.