
Commit 2214a3d

DynamicModuleList, moe_ffn hparam, Tests, Cleanup
Signed-off-by: Keval Morabia <[email protected]>
1 parent 3d246d7 commit 2214a3d

15 files changed (+662, -305 lines)

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
@@ -8,6 +8,10 @@ Model Optimizer Changelog (Linux)
 
 - Fix a bug in FastNAS pruning (computer vision models) where the model parameters were sorted twice messing up the ordering.
 
+**New Features**
+
+- Add MoE (e.g. Qwen3-30B-A3B) pruning support for ``num_moe_experts``, ``moe_ffn_hidden_size`` and ``moe_shared_expert_intermediate_size`` parameters in Minitron pruning (``mcore_minitron``).
+
 0.39 (2025-11-14)
 ^^^^^^^^^^^^^^^^^
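To make the new changelog entry concrete, here is a minimal sketch (not part of this commit) of passing the new MoE parameters through the ``mcore_minitron`` pruning mode. It assumes the ``mtp.prune`` workflow described in examples/pruning/README.md; the target values, the already-loaded ``model``, and the ``forward_loop`` calibration callable are illustrative placeholders.

```python
import modelopt.torch.prune as mtp

# Illustrative width-pruning targets for an MoE model (values are placeholders).
export_config = {
    "num_moe_experts": 96,                       # number of routed experts to keep
    "moe_ffn_hidden_size": 640,                  # per-expert FFN hidden size
    "moe_shared_expert_intermediate_size": 512,  # only for models with a shared expert
}

# `model` is an already-loaded Megatron-core / NeMo MoE model and `forward_loop`
# runs a few calibration batches through it (both assumed to exist here).
model, _ = mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"export_config": export_config},
    dummy_input=None,  # not used by mcore_minitron
    config={"forward_loop": forward_loop},
)
```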

docs/source/guides/7_nas.rst

Lines changed: 7 additions & 4 deletions
@@ -361,9 +361,12 @@ can be converted into searchable units:
     # search over the number of layers (depth) in the sequential layer.
     nn.Sequential
 
-    # We convert Megatron-core / NeMo GPT or Mamba style models (e.g. Llama3.1, NeMo Mistral, NeMotron-H, etc.)
-    # to automatically search over the MLP hidden size, number of attention heads, number of GQA groups,
-    # number of mamba heads, mamba head dimension, and depth of the model.
+    # We convert Megatron-core / NeMo GPT or MoE or Mamba Hybrid style models (e.g. Llama3, Nemotron-H, Qwen3-30B-A3B)
+    # to automatically search over the
+    # MLP hidden size, number of attention heads, number of GQA groups,
+    # number of mamba heads, mamba head dimension,
+    # number of moe experts, moe ffn hidden size, moe shared expert intermediate size,
+    # and depth of the model.
     megatron.core.models.gpt.GPTModel
     megatron.core.models.mamba.MambaModel
     nemo.collections.llm.gpt.model.base.GPTModel

@@ -640,7 +643,7 @@ The difference between NAS and pruning is summarized below.
 [Advanced] Adding a new NAS/Prune Algorithm
 ===========================================
 
-* Please refer to this `template <https://github.com/NVIDIA/TensorRT-Model-Optimizer/compare/template/new-nas-mode>`_
+* Please refer to this `template <https://github.com/NVIDIA/TensorRT-Model-Optimizer/compare/template/new-nas-mode>`_
   for adding a new NAS algorithm.
 * Please refer to `mcore_minitron.py <https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/prune/plugins/mcore_minitron.py>`_
   for an actual example of adding Minitron Pruning algorithm.
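As a quick illustration of the "searchable units" idea above, here is a minimal sketch (not from this commit) that converts a toy model with ``mtn.convert`` as the NAS guide describes; the toy architecture and the plain ``fastnas`` mode are illustrative, while Megatron-core GPT / MoE / Mamba-hybrid models are converted analogously through the Minitron plugin.

```python
import torch.nn as nn
import modelopt.torch.nas as mtn

# Toy model: the nn.Sequential body becomes searchable in depth, and the Conv /
# Linear layers in their channel / feature dimensions. Megatron-core GPT / MoE /
# Mamba-hybrid models get the hyperparameters listed above instead (MLP hidden
# size, attention heads, GQA groups, mamba heads/head dim, MoE experts/FFN sizes, depth).
model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, 32, 3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(32, 10),
)
model = mtn.convert(model, mode="fastnas")
```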

examples/megatron-lm/README.md

Lines changed: 5 additions & 2 deletions
@@ -20,7 +20,7 @@
 | Model | Quantization | EAGLE3 | Q-LoRA | Pruning (PP only) | Distillation |
 | :---: | :---: | :---: | :---: | :---: | :---: |
 | `moonshotai/Kimi-K2-Instruct` || **Online** | | | |
-| `Qwen/Qwen3-{30B-A3B, 235B-A22B}` | **WAR** | **Online** | | | |
+| `Qwen/Qwen3-{30B-A3B, 235B-A22B}` | **WAR** | **Online** | | | |
 | `Qwen/Qwen3-{0.6B, 8B}` || **Online** | |||
 | `deepseek-ai/DeepSeek-R1` || **Online** | | | |
 | `meta-llama/Llama-{3.1-8B, 3.1-405B, 3.2-1B}-Instruct` || **Online** | |||

@@ -112,14 +112,17 @@ Coming soon ...
 
 Checkout pruning [getting started section](../pruning/README.md#getting-started) and [guidelines](../pruning/README.md#pruning-guidelines) for configuring pruning parameters in the pruning README.
 
-Pruning is supported for GPT and Mamba models in Pipeline Parallel mode. Available pruning options are:
+Pruning is supported for GPT and Mamba models in Pipeline Parallel mode. Available pruning dimensions are:
 
 - `TARGET_FFN_HIDDEN_SIZE`
 - `TARGET_HIDDEN_SIZE`
 - `TARGET_NUM_ATTENTION_HEADS`
 - `TARGET_NUM_QUERY_GROUPS`
 - `TARGET_MAMBA_NUM_HEADS`
 - `TARGET_MAMBA_HEAD_DIM`
+- `TARGET_NUM_MOE_EXPERTS`
+- `TARGET_MOE_FFN_HIDDEN_SIZE`
+- `TARGET_MOE_SHARED_EXPERT_INTERMEDIATE_SIZE`
 - `TARGET_NUM_LAYERS`
 - `LAYERS_TO_DROP` (comma separated, 1-indexed list of layer numbers to directly drop)
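For orientation only, a hypothetical sketch (not from this repo) of how these `TARGET_*` knobs plausibly map onto the `mcore_minitron` export-config keys; the mapping is an assumption based on the naming convention, and the real wiring lives in the Megatron-LM pruning example scripts.

```python
import os

# Assumed mapping from the README's TARGET_* environment variables to the
# corresponding export-config keys used by mcore_minitron (illustrative only).
ENV_TO_EXPORT_KEY = {
    "TARGET_FFN_HIDDEN_SIZE": "ffn_hidden_size",
    "TARGET_HIDDEN_SIZE": "hidden_size",
    "TARGET_NUM_ATTENTION_HEADS": "num_attention_heads",
    "TARGET_NUM_QUERY_GROUPS": "num_query_groups",
    "TARGET_MAMBA_NUM_HEADS": "mamba_num_heads",
    "TARGET_MAMBA_HEAD_DIM": "mamba_head_dim",
    "TARGET_NUM_MOE_EXPERTS": "num_moe_experts",
    "TARGET_MOE_FFN_HIDDEN_SIZE": "moe_ffn_hidden_size",
    "TARGET_MOE_SHARED_EXPERT_INTERMEDIATE_SIZE": "moe_shared_expert_intermediate_size",
    "TARGET_NUM_LAYERS": "num_layers",
}

# Collect whichever targets are set in the environment into an export config.
export_config = {
    key: int(os.environ[env]) for env, key in ENV_TO_EXPORT_KEY.items() if env in os.environ
}
print(export_config)
```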

examples/pruning/README.md

Lines changed: 4 additions & 4 deletions
@@ -6,7 +6,7 @@ Pruning can involve removal (prune) of Linear and Conv layers, and Transformer a
 
 This section focuses on applying Model Optimizer's state-of-the-art complementary pruning modes to enable you to search for the best subnet architecture from your provided base model:
 
-1. [Minitron](https://arxiv.org/pdf/2408.11796): A pruning method developed by NVIDIA Research for pruning GPT, Mamba and Hybrid Transformer Mamba models in NVIDIA NeMo or Megatron-LM framework. It uses the activation magnitudes to prune the embedding hidden size, mlp ffn hidden size, transformer attention heads, GQA query groups, mamba heads and head dimension, and number of layers of the model.
+1. [Minitron](https://arxiv.org/pdf/2408.11796): A pruning method developed by NVIDIA Research for pruning GPT, Mamba and Hybrid Transformer Mamba models in NVIDIA NeMo or Megatron-LM framework. It uses the activation magnitudes to prune the embedding hidden size; mlp ffn hidden size; transformer attention heads and GQA query groups; mamba heads and head dimension; MoE number of experts, ffn hidden size, and shared expert intermediate size; and number of layers of the model.
 1. FastNAS: A pruning method recommended for Computer Vision models. Given a pretrained model, FastNAS finds the subnet which maximizes the score function while meeting the given constraints.
 1. GradNAS: A light-weight pruning method recommended for language models like Hugging Face BERT, GPT-J. It uses the gradient information to prune the model's linear layers and attention heads to meet the given constraints.

@@ -89,11 +89,11 @@ If your model parameters are already sorted, you can skip the sorting step by se
 
 | **Algorithm** | **Model** | **Pruning Constraints** |
 | :---: | :---: | :---: |
-| Minitron | Megatron-core / NeMo based GPT / Mamba / Hybrid Models<sup>1</sup> | Export config with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, `mamba_head_dim`) and/or depth (`num_layers`) values |
+| Minitron | Megatron-core / NeMo based GPT / Mamba / MoE / Hybrid Models<sup>1</sup> | Export config with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) values |
 | FastNAS | Computer Vision models | flops, parameters |
 | GradNAS | HuggingFace BERT, GPT-J | flops, parameters |
 
-> *<sup>1.</sup>Only Pipeline Parallel models are supported. Hugging Face models can be converted to NeMo format and used subsequently.*
+> *<sup>1.</sup>Only Pipeline Parallel models are supported. Hugging Face models can be converted to Megatron-LM/NeMo format and used subsequently.*
 
 ## Pruning Guidelines

@@ -122,7 +122,7 @@ Depth pruning reduces the number of layers (`num_layers`) in the model.
 
 #### Width Pruning
 
-Width pruning reduces model dimensions per layer such as `hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, and `mamba_head_dim`.
+Width pruning reduces model dimensions per layer such as `hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, and `moe_shared_expert_intermediate_size`.
 
 **Advantages:**

modelopt/torch/nas/modules/container.py

Lines changed: 33 additions & 1 deletion
@@ -26,7 +26,7 @@
 from ..registry import DMRegistry
 from ..traced_hp import TracedHp
 
-__all__ = ["_DynamicSequential"]
+__all__ = ["DynamicModuleList", "_DynamicSequential"]
 
 
 def _activate_depth(func: Callable) -> Callable:

@@ -97,3 +97,35 @@ def modify(self, *, min_depth: int = 0):
         """
         hp = self.get_hparam("depth")
         hp.choices = [d for d in hp.choices if d >= min_depth]
+
+
+# NOTE: We provide a parent class since we do not register to DMRegistry and explicitly convert a module if needed.
+class DynamicModuleList(DynamicModule, nn.ModuleList):
+    """An ``nn.ModuleList`` container with dynamic hyperparams and variable ``depth``.
+
+    Unlike _DynamicSequential, this module supports sorting/reordering of modules based on
+    importance in addition to variable depth.
+    """
+
+    def _setup(self):
+        # register hyperparameters
+        self._register_hparam("depth", TracedHp(list(range(1, len(self) + 1))))
+
+        # register _modules as a dynamic attribute
+        self._register_dynamic_attribute("_modules", self._get_modules)
+
+    @staticmethod
+    def _get_modules(mod: "DynamicModuleList", modules: dict) -> dict:
+        """Get modules with dynamic depth and ordering applied based on active_slice."""
+        hp = mod.get_hparam("depth")
+        active_slice = hp.active_slice
+
+        items = list(modules.items())
+
+        if isinstance(active_slice, slice):
+            active_items = items[active_slice]
+        else:
+            active_items = [items[idx] for idx in active_slice.tolist()]
+
+        # Re-create dict with keys as str(index) from 0 to len(active_items)
+        return {str(i): module for i, (_, module) in enumerate(active_items)}
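A short usage sketch (not part of this commit) of the new container, assuming ``DynamicModule.convert`` is the explicit conversion entry point referred to in the NOTE above; the expert modules and the chosen depth are illustrative.

```python
import torch.nn as nn
from modelopt.torch.nas.modules.container import DynamicModuleList

# A plain list of "experts" (placeholder modules).
experts = nn.ModuleList([nn.Linear(16, 16) for _ in range(8)])

# Explicit conversion, since the class is intentionally not registered in DMRegistry.
experts = DynamicModuleList.convert(experts)

# "depth" is now a hyperparameter with choices 1..8; activating a smaller depth
# exposes only a prefix of the modules (or a reordered subset once importances
# have been registered and the parameters sorted).
experts.get_hparam("depth").active = 6
print(len(experts))  # -> 6 active modules
```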
