Commit 050e1a5

Add GPT MoE pruning unit test
Signed-off-by: Keval Morabia <[email protected]>
1 parent: dec5105 · commit: 050e1a5

File tree

5 files changed: +129 additions, -10 deletions


CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ Model Optimizer Changelog (Linux)

 **New Features**

-- Add MoE pruning support for ``num_moe_experts`` and ``moe_shared_expert_intermediate_size`` in Minitron pruning (``mcore_minitron``).
+- Add MoE (e.g. Qwen3-30B-A3B) pruning support for ``num_moe_experts`` and ``moe_shared_expert_intermediate_size`` parameters in Minitron pruning (``mcore_minitron``).

 0.39 (2025-11-07)
 ^^^^^^^^^^^^^^^^^

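For quick reference, a minimal sketch of how the new MoE dimensions are driven through the pruning API, mirroring the unit test added in this commit. The `model` and `forward_loop` objects are assumed to be set up elsewhere (an MCore GPT model with MoE layers and a short calibration loop), and the target values are illustrative.

import modelopt.torch.prune as mtp

# Sketch only: `model` and `forward_loop` are assumed to exist already.
export_config = {
    "num_moe_experts": 4,  # e.g. prune 8 routed experts down to 4
    "moe_shared_expert_intermediate_size": 64,  # e.g. shrink the shared-expert FFN width
}
mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"export_config": export_config},
    dummy_input=None,  # not used by mcore_minitron
    config={"scores_path": "minitron_scores.pth", "forward_loop": forward_loop},
)
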
examples/megatron-lm/README.md

Lines changed: 4 additions & 2 deletions
@@ -20,7 +20,7 @@
 | Model | Quantization | EAGLE3 | Q-LoRA | Pruning (PP only) | Distillation |
 | :---: | :---: | :---: | :---: | :---: | :---: |
 | `moonshotai/Kimi-K2-Instruct` || **Online** | | | |
-| `Qwen/Qwen3-{30B-A3B, 235B-A22B}` | **WAR** | **Online** | | | |
+| `Qwen/Qwen3-{30B-A3B, 235B-A22B}` | **WAR** | **Online** | | | |
 | `Qwen/Qwen3-{0.6B, 8B}` || **Online** | |||
 | `deepseek-ai/DeepSeek-R1` || **Online** | | | |
 | `meta-llama/Llama-{3.1-8B, 3.1-405B, 3.2-1B}-Instruct` || **Online** | |||
@@ -112,14 +112,16 @@ Coming soon ...

 Checkout pruning [getting started section](../pruning/README.md#getting-started) and [guidelines](../pruning/README.md#pruning-guidelines) for configuring pruning parameters in the pruning README.

-Pruning is supported for GPT and Mamba models in Pipeline Parallel mode. Available pruning options are:
+Pruning is supported for GPT and Mamba models in Pipeline Parallel mode. Available pruning dimensions are:

 - `TARGET_FFN_HIDDEN_SIZE`
 - `TARGET_HIDDEN_SIZE`
 - `TARGET_NUM_ATTENTION_HEADS`
 - `TARGET_NUM_QUERY_GROUPS`
 - `TARGET_MAMBA_NUM_HEADS`
 - `TARGET_MAMBA_HEAD_DIM`
+- `TARGET_NUM_MOE_EXPERTS`
+- `TARGET_MOE_SHARED_EXPERT_INTERMEDIATE_SIZE`
 - `TARGET_NUM_LAYERS`
 - `LAYERS_TO_DROP` (comma separated, 1-indexed list of layer numbers to directly drop)

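The two new MoE dimensions plug into the same Minitron export_config mechanism as the existing ones. Below is a small sketch under the assumption that each TARGET_* knob above corresponds one-to-one to an export_config key consumed by mode="mcore_minitron"; the helper function is illustrative and not part of the example scripts.

import os

def _target(name: str) -> int | None:
    # Read an optional TARGET_* knob from the environment (illustrative helper).
    val = os.environ.get(name)
    return int(val) if val else None

export_config = {
    "ffn_hidden_size": _target("TARGET_FFN_HIDDEN_SIZE"),
    "hidden_size": _target("TARGET_HIDDEN_SIZE"),
    "num_moe_experts": _target("TARGET_NUM_MOE_EXPERTS"),
    "moe_shared_expert_intermediate_size": _target("TARGET_MOE_SHARED_EXPERT_INTERMEDIATE_SIZE"),
}
# Keep only the dimensions that were actually requested.
export_config = {k: v for k, v in export_config.items() if v is not None}
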
modelopt/torch/nas/plugins/megatron.py

Lines changed: 2 additions & 2 deletions
@@ -263,7 +263,7 @@ def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
         self.linear_fc1.input_size = hidden_size
         self.linear_fc2.output_size = hidden_size

-    def modify(self, ffn_hidden_size_divisor: int) -> None:
+    def modify(self, ffn_hidden_size_divisor: int, **kwargs) -> None:
         """Modify the ffn_hidden_size hparam choices based on search space config."""
         hp_mlp = self.get_hparam(self.hparam_name)
         choices = {int(make_divisible(c, ffn_hidden_size_divisor)) for c in hp_mlp.choices}  # type: ignore[arg-type]
@@ -937,7 +937,7 @@ def modify(
             hp.choices = list(set(hp.choices) & choices | {hp.original})

         # Modify MLP hparam (regular or MoE)
-        elif isinstance(self.mlp, (MLP, MoELayer)):
+        if isinstance(self.mlp, (MLP, MoELayer)):
             self.mlp.modify(
                 ffn_hidden_size_divisor=ffn_hidden_size_divisor,
                 num_moe_experts_divisor=num_moe_experts_divisor,

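The `**kwargs` addition above lets the transformer layer pass the full set of divisors to whichever MLP hparam module it owns, while the dense MLP variant simply ignores the MoE-specific ones. A toy sketch of that calling pattern follows; the class names are illustrative, not the actual modelopt classes.

class DenseMLPSearchSpace:
    def modify(self, ffn_hidden_size_divisor: int, **kwargs) -> None:
        # A dense MLP only cares about the FFN divisor; extra MoE kwargs are ignored.
        print(f"dense MLP: round ffn_hidden_size choices to multiples of {ffn_hidden_size_divisor}")


class MoESearchSpace:
    def modify(self, ffn_hidden_size_divisor: int, num_moe_experts_divisor: int, **kwargs) -> None:
        print(f"MoE: ffn divisor {ffn_hidden_size_divisor}, expert-count divisor {num_moe_experts_divisor}")


# The caller no longer needs to special-case dense vs. MoE MLPs:
for mlp in (DenseMLPSearchSpace(), MoESearchSpace()):
    mlp.modify(ffn_hidden_size_divisor=64, num_moe_experts_divisor=2)
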
tests/_test_utils/torch/megatron/models.py

Lines changed: 9 additions & 4 deletions
@@ -142,10 +142,12 @@ def get_mcore_gpt_model(
     normalization: str = "LayerNorm",
     transformer_impl: str = "modelopt" if HAS_TE else "local",
     use_cpu_initialization: bool = False,
-    num_moe_experts: int | None = None,
-    moe_grouped_gemm: bool = False,
     bf16: bool = True,
     use_te: bool = False,
+    # MoE-specific parameters
+    moe_grouped_gemm: bool = False,
+    moe_shared_expert_intermediate_size: int | None = None,
+    num_moe_experts: int | None = None,
 ) -> GPTModel:
     assert activation_func in ["swiglu", "squared_relu"]
     assert normalization in ["LayerNorm", "RMSNorm"]
@@ -169,22 +171,25 @@ def squared_relu(x):
         expert_model_parallel_size=expert_model_parallel_size,
         expert_tensor_parallel_size=expert_tensor_parallel_size,
         sequence_parallel=False,
-        moe_grouped_gemm=moe_grouped_gemm,
         num_layers=num_layers,
         num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage,
         num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage,
         hidden_size=hidden_size,
         num_attention_heads=num_attention_heads,
         num_query_groups=num_query_groups,
         ffn_hidden_size=ffn_hidden_size,
-        num_moe_experts=num_moe_experts,
         activation_func=squared_relu if activation_func == "squared_relu" else F.silu,
         normalization=normalization,
         gated_linear_unit=(activation_func == "swiglu"),
         add_bias_linear=False,
         use_cpu_initialization=use_cpu_initialization,
         pipeline_dtype=torch.bfloat16 if bf16 else torch.float32,
         bf16=bf16,
+        # MoE-specific parameters
+        moe_grouped_gemm=moe_grouped_gemm,
+        moe_router_dtype="fp32",
+        moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size,
+        num_moe_experts=num_moe_experts,
     )

     if transformer_impl == "local":

tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Lines changed: 113 additions & 1 deletion
@@ -134,7 +134,6 @@ def forward_loop(m):

     # Assert weights are pruned correctly
     for layer in model.decoder.layers:
-        print(rank, layer.mlp)
         assert layer.mlp.linear_fc1.weight.shape == (
             pruned_ffn * (2 if activation_func == "swiglu" else 1),
             pruned_hidden_size,
@@ -238,3 +237,116 @@ def test_mcore_gpt_pruning(
         ),
         backend="nccl",
     )
+
+
+def _test_mcore_gpt_pruning_moe(ckpt_path, rank, size):
+    num_layers = size
+    hidden_size = 256
+    ffn_hidden_size = 256
+    num_moe_experts = 8
+    moe_shared_expert_intermediate_size = 128
+    max_sequence_length = 16
+    vocab_size = 64
+    batch_size = 2
+
+    def _get_model(initialize_megatron=True):
+        model = get_mcore_gpt_model(
+            tensor_model_parallel_size=1,
+            pipeline_model_parallel_size=size,
+            initialize_megatron=initialize_megatron,
+            num_layers=num_layers,
+            hidden_size=hidden_size,
+            ffn_hidden_size=ffn_hidden_size,
+            max_sequence_length=max_sequence_length,
+            vocab_size=vocab_size,
+            activation_func="squared_relu",
+            num_moe_experts=num_moe_experts,
+            moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size,
+        ).cuda()
+        return model
+
+    model = _get_model()
+    sd = model.state_dict()
+
+    def forward_loop(m):
+        for _ in range(5):
+            run_mcore_inference_with_dummy_input(m, batch_size, hidden_size)
+
+    pruned_ffn = ffn_hidden_size // 2
+    pruned_hidden_size = hidden_size // 2
+    pruned_num_moe_experts = num_moe_experts // 2
+    pruned_moe_ffn = moe_shared_expert_intermediate_size // 2
+
+    export_config = {
+        "ffn_hidden_size": pruned_ffn,
+        "hidden_size": pruned_hidden_size,
+        "num_moe_experts": pruned_num_moe_experts,
+        "moe_shared_expert_intermediate_size": pruned_moe_ffn,
+    }
+
+    mtp.prune(
+        model,
+        mode="mcore_minitron",
+        constraints={"export_config": export_config},
+        dummy_input=None,  # Not used
+        config={"scores_path": ckpt_path, "forward_loop": forward_loop},
+    )
+
+    # Assert weights are pruned correctly
+    for layer in model.decoder.layers:
+        moe = layer.mlp
+        assert moe.router.num_experts == pruned_num_moe_experts
+        assert moe.router.weight.shape == (pruned_num_moe_experts, pruned_hidden_size)
+        assert moe.experts.num_local_experts == pruned_num_moe_experts
+        assert len(moe.experts.local_experts) == pruned_num_moe_experts
+        for expert in moe.experts.local_experts:
+            assert expert.linear_fc1.weight.shape == (pruned_ffn, pruned_hidden_size), (
+                expert.linear_fc1.weight.shape,
+                pruned_ffn,
+                pruned_hidden_size,
+            )
+            assert expert.linear_fc2.weight.shape == (pruned_hidden_size, pruned_ffn), (
+                expert.linear_fc2.weight.shape,
+                pruned_hidden_size,
+                pruned_ffn,
+            )
+        assert moe.shared_experts.linear_fc1.weight.shape == (
+            pruned_moe_ffn,
+            pruned_hidden_size,
+        )
+        assert moe.shared_experts.linear_fc2.weight.shape == (
+            pruned_hidden_size,
+            pruned_moe_ffn,
+        )
+
+    # Assert model.config is updated for correct save/restoring
+    assert model.config.ffn_hidden_size == pruned_ffn
+    assert model.config.hidden_size == pruned_hidden_size
+    assert model.config.num_moe_experts == pruned_num_moe_experts
+    assert model.config.moe_shared_expert_intermediate_size == pruned_moe_ffn
+
+    # Assert forward pass works on the pruned model
+    prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
+    output = run_mcore_inference(model, prompt_tokens, pruned_hidden_size)
+
+    # Assert re-pruning from scores_path works without running the forward loop again
+    model_rerun = _get_model(initialize_megatron=False)
+    model_rerun.load_state_dict(sd)
+    mtp.prune(
+        model_rerun,
+        mode="mcore_minitron",
+        constraints={"export_config": export_config},
+        dummy_input=None,  # Not used
+        config={"scores_path": ckpt_path},
+    )
+
+    output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size)
+    assert torch.allclose(output, output_rerun, atol=1e-5)
+
+
+def test_mcore_gpt_pruning_moe(tmp_path):
+    spawn_multiprocess_job(
+        size=torch.cuda.device_count(),
+        job=partial(_test_mcore_gpt_pruning_moe, tmp_path / "minitron_scores.pth"),
+        backend="nccl",
+    )
