Add minitron pruning and distillation guidelines in pruning readme #419
Conversation
**Walkthrough**

Documentation-only changes: changelog dates and ordering adjusted; new Documentation entry added linking to pruning README; extensive expansion of pruning/distillation guidance in the pruning example README; formatting and reflow edits only—no code or API changes.

**Estimated code review effort:** 🎯 2 (Simple) | ⏱️ ~10 minutes

**Pre-merge checks and finishing touches:** ✅ Passed checks (3 passed)
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
CHANGELOG.rst (2 hunks), examples/pruning/README.md (3 hunks)
**Codecov Report**

✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff           @@
##             main     #419   +/-  ##
=======================================
  Coverage   73.36%   73.36%
=======================================
  Files         180      180
  Lines       17919    17919
=======================================
+ Hits        13146    13147     +1
- Misses       4773     4772     -1
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Keval Morabia <[email protected]>
bc228d2 to 698ed38 (force-pushed)
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
CHANGELOG.rst (2 hunks), examples/pruning/README.md (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- CHANGELOG.rst
LGTM;
Signed-off-by: Keval Morabia <[email protected]>
2671fe0 to 0da4caa (force-pushed)
Actionable comments posted: 2
🧹 Nitpick comments (3)
examples/pruning/README.md (3)
**115-116: Tighten Pareto phrasing and clarify iterative pruning.**

“above the Pareto frontier” is imprecise; clarify “on or near,” and specify the per‑iteration compression amount.

```diff
-- Up to **1/3rd parameter reduction** can generally result in a model above the Pareto frontier with good latency-accuracy trade-off (when using a good quality dataset for distillation with ~80-100B tokens)
+- Up to **1/3rd parameter reduction** can generally result in a model on or near the Pareto frontier with a good latency‑accuracy trade‑off (when using a high‑quality distillation dataset with ~80–100B tokens)
-- For pruning **>50%**, use iterative pruning: compress by 30%, perform distillation, then compress again
+- For pruning **>50%**, use iterative pruning: compress by ~30% per iteration (≈0.7× parameters), distill, then optionally repeat
```
**133-137: Add divisibility/shape constraints to avoid invalid configs.**

Call out common width‑pruning constraints (GQA, heads, head_dim) to prevent shape errors.

```diff
 - **Axis sensitivity:** MLP dimensions (`ffn_hidden_size`) can typically be pruned more aggressively than embedding dimensions (`hidden_size`) and attention/Mamba dimensions (`num_attention_heads`, `num_query_groups`, `mamba_num_heads`, `mamba_head_dim`)
+- Maintain divisibility/shape constraints: ensure `hidden_size % num_attention_heads == 0`; `num_attention_heads % num_query_groups == 0` (for GQA); and, where applicable, `mamba_num_heads * mamba_head_dim == hidden_size`. Adjust linked axes together to keep tensor shapes valid.
 - For pruning **>50%**, use iterative pruning: compress by 30%, perform distillation, then compress again
```
**163-172: Round out distillation knobs (KD T and loss mixing).**

Add KD temperature and KL/CE loss‑mix guidance to make the recipe actionable.

```diff
 | **Learning Rate (LR)** | 1e-4 → 1e-5 (linear decay) for 30-50% pruning<br>• More compression → higher LR<br>• Less compression → lower LR<br>• As model gets larger → reduce LR to avoid divergence |
+| **KD Temperature (T)** | 1.0–2.0 (use higher T for larger compression; tune per dataset) |
+| **KD loss mixing (alpha)** | 0.5–0.8 weight on KL divergence vs. 0.2–0.5 on CE to labels |
 | **Warmup Steps** | 100 |
```
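To make the suggested knobs concrete, here is a minimal sketch of a temperature-scaled KD loss in the spirit of this suggestion; the function name, defaults, and shapes are illustrative, not part of the README:

```python
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, T=1.5, alpha=0.7):
    # Assumes flattened (N, vocab_size) logits and (N,) integer labels.
    # Temperature-softened KL between teacher and student distributions;
    # the T*T factor keeps gradient scale comparable across temperatures.
    kl = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    # Standard cross-entropy against the ground-truth labels.
    ce = F.cross_entropy(student_logits, labels)
    # alpha weights the distillation (KL) term, per the suggested 0.5-0.8 range.
    return alpha * kl + (1 - alpha) * ce
```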
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/pruning/README.md (3 hunks)
🔇 Additional comments (2)
examples/pruning/README.md (2)
**20-22: TOC updates look good.**

Anchors resolve to existing sections. No issues.
**140-143: Confirm model specs and citations**
- Qwen3-8B defaults to 36 layers (not 24) and Llama-3.1-8B to 32 layers (not 16); the “36→24” and “32→16” numbers are pruning depth targets, not official defaults.
- Nemotron-H-8B’s hidden_size=4096, ffn_hidden_size=21504 and Mamba state_dim/groups match the Nemotron-H paper; arXiv:2504.11409 discusses compression experiments on that base model.
- Nemotron-Nano-9B-v2 (56L, hidden_size=4480, ffn_hidden_size=15680, mamba_num_heads=128) and 12B-v2 (62L, hidden_size=5120, ffn_hidden_size=20480, mamba_num_heads=128) are documented in NeMo/Megatron-bridge and described in arXiv:2508.14444.
## Pruning Guidelines

### Minitron

This section provides recommendations for choosing pruning strategies and distillation hyperparameters for Minitron pruning to help achieve the best latency-accuracy trade-offs.

#### Depth Pruning

Depth pruning reduces the number of layers (`num_layers`) in the model.

**Advantages:**

- Simpler to configure - only 1 parameter to tune
- Faster inference than width-pruned models at a fixed number of parameters

**Recommendations:**

- Up to **1/3rd parameter reduction** can generally result in a model above the Pareto frontier with good latency-accuracy trade-off (when using a good quality dataset for distillation with ~80-100B tokens)
- For pruning **>50%**, use iterative pruning: compress by 30%, perform distillation, then compress again

**Examples:**

- [Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B) (`num_layers=36`) → 6B (`num_layers=24`)
- [Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) (`num_layers=32`) → 4.5B (`num_layers=16`)
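In code, a depth-pruning run of the kind shown above would look roughly like the following sketch. The mode name, the tuple return value, and `forward_loop` are assumptions based on the ModelOpt pruning docs rather than this README; check the documentation for the exact signature:

```python
import modelopt.torch.prune as mtp

# Hypothetical sketch: prune a Megatron-Core model from 36 to 24 layers.
# `model` and `forward_loop` (which feeds calibration batches through the
# model) are placeholders for objects from your own training setup.
pruned_model, _ = mtp.prune(
    model,
    mode="mcore_minitron",  # Minitron pruning mode; verify the name in your version
    constraints={"export_config": {"num_layers": 24}},
    dummy_input=None,       # not needed when a forward_loop is supplied
    config={"forward_loop": forward_loop},
)
```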
#### Width Pruning

Width pruning reduces model dimensions per layer such as `hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, and `mamba_head_dim`.

**Advantages:**

- Better accuracy than depth-pruned models at a fixed number of parameters

**Recommendations:**

- Start with pruning `hidden_size` and `ffn_hidden_size` as the simplest configuration
- Up to **1/3rd parameter reduction** can generally result in a model above the Pareto frontier with good latency-accuracy trade-off (when using a good quality dataset for distillation with ~80-100B tokens)
- **Axis sensitivity:** MLP dimensions (`ffn_hidden_size`) can typically be pruned more aggressively than embedding dimensions (`hidden_size`) and attention/Mamba dimensions (`num_attention_heads`, `num_query_groups`, `mamba_num_heads`, `mamba_head_dim`)
- For pruning **>50%**, use iterative pruning: compress by 30%, perform distillation, then compress again

**Examples:**

- [Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B) (`ffn_hidden_size=12288`, `hidden_size=4096`) → 6B (`ffn_hidden_size=9216`, `hidden_size=3584`)
- [Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) (`ffn_hidden_size=14336`, `hidden_size=4096`) → 4.5B (`ffn_hidden_size=9216`, `hidden_size=3072`)
- [Nemotron-H-8B-Base-8K](https://huggingface.co/nvidia/Nemotron-H-8B-Base-8K) (`ffn_hidden_size=21504`, `hidden_size=4096`, `mamba_num_heads=128`) → [Nemotron-H-4B-Base-8K](https://huggingface.co/nvidia/Nemotron-H-4B-Base-8K) (`ffn_hidden_size=12288`, `hidden_size=3072`, `mamba_num_heads=112`) - See [paper](https://arxiv.org/pdf/2504.11409)
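A width-pruning call differs from the depth sketch above only in its export constraints. For example, constraints matching the Qwen3-8B → 6B row (same hedged caveats as the earlier sketch apply):

```python
# Width-pruning constraints matching the Qwen3-8B → 6B example above.
width_constraints = {
    "export_config": {
        "hidden_size": 3584,
        "ffn_hidden_size": 9216,
    }
}
```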
#### Depth and Width Pruning

For optimal results, combine depth and width pruning. This will require more tuning to find the best architecture.

**Examples:**

- [NVIDIA-Nemotron-Nano-12B-v2](https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-12B-v2) (`ffn_hidden_size=20480`, `hidden_size=5120`, `num_layers=62`) → [NVIDIA-Nemotron-Nano-9B-v2](https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-9B-v2) (`ffn_hidden_size=15680`, `hidden_size=4480`, `num_layers=56`) - See [paper](https://arxiv.org/pdf/2508.14444)
#### General Pruning Guidelines

- **Pruning ratio:** Anything **>50% pruning is hard to recover**. For such aggressive pruning, iterative pruning (compress → distill → compress again) is recommended.
- **Latency-accuracy trade-off:** The more pruning you do, the faster your model will be at the cost of lower accuracy. Choose based on your requirements.
- **Dataset quality:** Use a high-quality dataset for distillation. If you don't have a specific dataset, [Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1) is recommended.
- **Post-training:** Further post-training (e.g., instruction tuning, preference alignment) is needed after pruning and distillation on pre-training datasets to improve reasoning capabilities. A good dataset for post-training is [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2).
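The iterative compress → distill → compress recipe from the first bullet can be written out as a small driver loop; every helper here is a placeholder for your own pruning and distillation pipeline:

```python
# Illustrative only: reach >50% overall pruning via ~30% steps.
target_keep = 0.4   # keep 40% of the original parameters overall
step_keep = 0.7     # each iteration keeps ~70% of the current parameters

keep = 1.0
# Stop once another ~30% step would overshoot the target size.
while keep * step_keep >= target_keep:
    keep *= step_keep
    model = prune_step(model, step_keep)           # placeholder Minitron prune call
    model = distill(model, teacher, tokens=90e9)   # placeholder distillation run
```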
#### Distillation Hyperparameters

After pruning, distillation is required to recover model accuracy. Below are recommended starting hyperparameters for distillation:

| **Hyperparameter** | **Recommendation** |
| :---: | :---: |
| **Sequence Length** | 8192 (or 4096 if dataset has smaller sequences) |
| **Global Batch Size (GBS)** | 768 |
| **Micro Batch Size (MBS)** | As large as your GPU memory can accommodate |
| **Learning Rate (LR)** | 1e-4 → 1e-5 (linear decay) for 30-50% pruning<br>• More compression → higher LR<br>• Less compression → lower LR<br>• As model gets larger → reduce LR to avoid divergence |
| **Warmup Steps** | 100 |
| **Training Max Steps** | Num training tokens / (Seq len × GBS)<br>• Recommended: 80-100B tokens |
| **Data Composition** | • Standard models: 100% pre-training data<br>• Reasoning models: 70% reasoning data + 30% pre-training data |

> [!TIP]
> If you know the maximum learning rate used during the original training, a good rule of thumb for knowledge distillation is to use **1/5th of that maximum LR** when compressing by ~50%.
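Worked numbers for the two formulas above (plain arithmetic; the original max LR is an assumed value for illustration):

```python
# Training Max Steps = num training tokens / (seq len × GBS)
tokens = 90e9               # midpoint of the recommended 80-100B range
seq_len, gbs = 8192, 768
print(round(tokens / (seq_len * gbs)))   # ≈ 14305 steps

# TIP rule of thumb at ~50% compression: KD LR ≈ original max LR / 5
original_max_lr = 3e-4      # assumed for illustration
print(original_max_lr / 5)  # 6e-05
```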
🧩 Analysis chain
Sanity‑check external links and local notebook path.
Run this to quickly validate URLs (HTTP 200/301) and ensure the notebook exists.
🏁 Script executed:
```bash
#!/usr/bin/env bash
set -euo pipefail
FILE="examples/pruning/README.md"
echo "Checking external links..."
# Extract and de-duplicate http(s) links
mapfile -t LINKS < <(grep -oE '\(https?://[^) ]+\)' -n "$FILE" | sed -E 's/.*\((https?:\/\/[^) ]+)\).*/\1/' | sort -u)
fail=0
for u in "${LINKS[@]}"; do
  code=$(curl -m 10 -sS -o /dev/null -w '%{http_code}' -L "$u" || echo "000")
  printf '%-3s %s\n' "$code" "$u"
  [[ "$code" =~ ^[23] ]] || fail=1  # accept only 2xx/3xx responses
done
echo
echo "Checking local notebook path..."
if [[ -f "examples/pruning/cifar_resnet.ipynb" ]]; then
  echo "OK examples/pruning/cifar_resnet.ipynb exists"
else
  echo "ERR examples/pruning/cifar_resnet.ipynb missing"
  fail=1
fi
exit $fail
```

Length of output: 83
🏁 Script executed:
```bash
#!/usr/bin/env bash
set -euo pipefail
FILE="examples/pruning/README.md"
echo "Checking external links..."
fail=0
# Read from process substitution (not a pipeline) so `fail=1` set inside the
# loop survives; a piped `while` runs in a subshell and would lose it.
while read -r url; do
  code=$(curl -m 10 -sS -o /dev/null -w '%{http_code}' -L "$url" || echo "000")
  printf '%-3s %s\n' "$code" "$url"
  if ! [[ "$code" =~ ^[23] ]]; then
    fail=1
  fi
done < <(grep -oE '\(https?://[^) ]+\)' "$FILE" \
  | sed -E 's/.*\((https?:\/\/[^) ]+)\).*/\1/' \
  | sort -u)
echo
echo "Checking local notebook path..."
if [[ -f "examples/pruning/cifar_resnet.ipynb" ]]; then
  echo "OK examples/pruning/cifar_resnet.ipynb exists"
else
  echo "ERR examples/pruning/cifar_resnet.ipynb missing"
  fail=1
fi
exit $fail
```

Length of output: 1529
Broken external link: the URL for Nemotron-H-4B-Base-8K (https://huggingface.co/nvidia/Nemotron-H-4B-Base-8K) returns HTTP 401; update it to a publicly accessible link or ensure the model page is published.
🤖 Prompt for AI Agents
In examples/pruning/README.md around lines 98 to 175 the reference link for
Nemotron-H-4B-Base-8K (https://huggingface.co/nvidia/Nemotron-H-4B-Base-8K)
returns HTTP 401; update the link to a publicly accessible URL or replace it
with the correct published Hugging Face model/dataset page or an alternative
public resource, and verify the new URL returns 200; also ensure any adjacent
links referencing the same private artifact are updated consistently and the
link text remains accurate.
```diff
 Check out the FastNAS pruning example usage in the [documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/3_pruning.html#pruning-and-subnet-search).

 You can also take a look at FastNAS pruning interactive notebook [cifar_resnet](./cifar_resnet.ipynb) in this directory
 which showcases the usage of FastNAS for pruning a ResNet 20 model for the CIFAR-10 dataset. The notebook
-also how to profiling the model to understand the search space of possible pruning options and demonstrates
-the usage saving and restoring pruned models.
+also shows how to profile the model to understand the search space of possible pruning options and demonstrates
+how to save and restore pruned models.
```
Polish FastNAS paragraph (articles, hyphenation, flow).
Minor grammar/style fixes.
```diff
-Check out the FastNAS pruning example usage in the [documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/3_pruning.html#pruning-and-subnet-search).
+Check out the FastNAS pruning example usage in the [documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/3_pruning.html#pruning-and-subnet-search).
-You can also take a look at FastNAS pruning interactive notebook [cifar_resnet](./cifar_resnet.ipynb) in this directory
-which showcases the usage of FastNAS for pruning a ResNet 20 model for the CIFAR-10 dataset. The notebook
-also shows how to profile the model to understand the search space of possible pruning options and demonstrates
-how to save and restore pruned models.
+You can also take a look at the FastNAS pruning interactive notebook [cifar_resnet](./cifar_resnet.ipynb) in this directory,
+which shows how to use FastNAS to prune a ResNet‑20 model on the CIFAR‑10 dataset. The notebook
+also shows how to profile the model to understand the search space of possible pruning options and
+how to save and restore pruned models.
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```suggestion
Check out the FastNAS pruning example usage in the [documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/3_pruning.html#pruning-and-subnet-search).

You can also take a look at the FastNAS pruning interactive notebook [cifar_resnet](./cifar_resnet.ipynb) in this directory,
which shows how to use FastNAS to prune a ResNet-20 model on the CIFAR-10 dataset. The notebook
also shows how to profile the model to understand the search space of possible pruning options and
how to save and restore pruned models.
```
🤖 Prompt for AI Agents
In examples/pruning/README.md around lines 191 to 196, the FastNAS paragraph
needs minor grammar and style polishing: add definite/indefinite articles where
appropriate, hyphenate compound adjectives (e.g., "interactive notebook" is fine
but "FastNAS pruning" could be "the FastNAS pruning"), improve flow by combining
sentences and clarifying references, and ensure consistent punctuation. Edit the
text to read smoothly (e.g., reference the documentation link, refer to "the
FastNAS pruning interactive notebook cifar_resnet.ipynb in this directory,"
mention "ResNet-20" with hyphenation, and use "CIFAR-10 dataset"), and ensure
the final sentences clearly state that the notebook profiles the model to
explore pruning options and demonstrates saving and restoring pruned models.
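For context, the notebook's FastNAS flow boils down to roughly the following pattern. This is a sketch based on the linked pruning guide; `model`, `dummy_input`, `train_loader`, `score_func`, the constraint value, and the file name are all placeholders:

```python
import modelopt.torch.prune as mtp
import modelopt.torch.opt as mto

# Sketch of FastNAS pruning for a ResNet-20 on CIFAR-10 (placeholder names).
pruned_model, _ = mtp.prune(
    model=model,                       # trained ResNet-20
    mode="fastnas",
    constraints={"flops": "70%"},      # keep ~70% of the original FLOPs
    dummy_input=dummy_input,           # e.g., a (1, 3, 32, 32) CIFAR-10 tensor
    config={
        "data_loader": train_loader,   # used to score candidate subnets
        "score_func": score_func,      # e.g., returns validation accuracy
    },
)

# Save the pruned architecture and weights for later fine-tuning or restore.
mto.save(pruned_model, "modelopt_pruned_model.pth")
```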
Signed-off-by: Keval Morabia <[email protected]>
Signed-off-by: Hrishith Thadicherla <[email protected]>
What does this PR do?
Type of change: Documentation update