
Commit 3d7f8af

Support checkpointing Minitron scores

Signed-off-by: Keval Morabia <[email protected]>
1 parent 115b145 commit 3d7f8af

File tree

10 files changed, +153 -70 lines changed


CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ Model Optimizer Changelog (Linux)
 - ``high_precision_dtype`` default to fp16 in ONNX quantization, i.e. quantized output model weights are now FP16 by default.
 - Upgrade TensorRT-LLM dependency to 1.1.0rc2.
 - Support Phi-4-multimodal and Qwen2.5-VL quantized HF checkpoint export in ``examples/vlm_ptq``.
+- Support storing and restoring Minitron pruning activations and scores for re-pruning without running the forward loop again.
 - Add Minitron pruning example for Megatron-LM framework. See ``examples/megatron-lm`` for more details.

 0.35 (2025-09-04)

docs/source/guides/3_pruning.rst

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ Pruning

 .. tip::

-    Checkout `Llama 3.1 NeMo Minitron Pruning <https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/llama/pruning-distillation>`_ and
+    Checkout `Qwen 3 NeMo Minitron Pruning & Distillation <https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation>`_ and
     `ResNet20 on CIFAR-10 Notebook <https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/examples/pruning/cifar_resnet.ipynb>`_
     for an end-to-end example of pruning.

examples/llm_distill/README.md

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ Loss balancers:

 Checkout the stand-alone distillation script in the [NVIDIA NeMo repository](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/distillation/distillation.html).

-You can also look at the tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/llama/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Llama 3.1 8B step-by-step in NeMo framework.
+You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Qwen 3 8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.

 ## Knowledge Distillation (KD) for HuggingFace Models

examples/pruning/README.md

Lines changed: 5 additions & 3 deletions
@@ -67,12 +67,14 @@ export_config = {


 # Run the pruning process
+# Save minitron scores so we can re-run pruning with different export configs without running the forward loop again
+# NOTE: Skip checkpoint path on re-running if you want to change the dataset
 mtp.prune(
     model,
     mode="mcore_minitron",
     constraints={"export_config": export_config},
     dummy_input=None,  # Not used
-    config={"forward_loop": forward_loop},
+    config={"forward_loop": forward_loop, "checkpoint": "modelopt_minitron_scores.pth"},
 )
 ```

@@ -91,11 +93,11 @@ mtp.prune(

 ## Examples

-### Minitron Pruning for Megatron-LM / NeMo Framework LLMs (e.g. Llama 3.1, Nemotron Nano)
+### Minitron Pruning for Megatron-LM / NeMo Framework LLMs (e.g. Qwen 3, Nemotron Nano)

 Checkout the Minitron pruning example for the [Megatron-LM Framework](../megatron-lm/README.md#-pruning) and [NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/pruning/pruning.html) which showcases the usage of the powerful Minitron pruning algorithm developed by NVIDIA Research for pruning LLMs like Llama 3.1 8B, Qwen 3 8B, Nemotron Nano 12B v2, etc.

-You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/llama/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Llama 3.1 8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.
+You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Qwen 3 8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.

 Some of the models pruned using Minitron method followed by distillation and post-training are:

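The README hunk above only shows the first pruning call. As a rough sketch of the two-step flow this commit enables, assuming `model`, `forward_loop`, and `export_config` are set up as in the README, and with `fresh_model`/`smaller_export_config` as hypothetical stand-ins for a second run (only `mtp.prune`, the `mcore_minitron` mode, and the `checkpoint` config key come from the diff):

```python
import modelopt.torch.prune as mtp

# First run: the forward loop is executed and the collected Minitron activations
# and scores are saved to the checkpoint path in addition to pruning the model.
mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"export_config": export_config},
    dummy_input=None,  # Not used
    config={"forward_loop": forward_loop, "checkpoint": "modelopt_minitron_scores.pth"},
)

# Later run on a freshly built model: pass the same checkpoint with a different
# export config; the stored scores are restored, so no forward_loop is needed.
mtp.prune(
    fresh_model,
    mode="mcore_minitron",
    constraints={"export_config": smaller_export_config},  # hypothetical second target
    dummy_input=None,  # Not used
    config={"checkpoint": "modelopt_minitron_scores.pth"},
)
```

Per the note added in the hunk, skip the checkpoint on re-runs that change the calibration dataset, since the stored scores would no longer reflect it.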
modelopt/torch/nas/plugins/megatron.py

Lines changed: 5 additions & 5 deletions
@@ -1234,10 +1234,10 @@ def _emb_layernorm_forward_hook(self, module, input, output) -> None:
         output = output.to(torch.float32)  # use full precision to avoid overflow
         activations = output.abs().mean(dim=0)  # [batch_size, hidden_size]
         activations = activations.pow(2).sum(dim=0)  # [hidden_size]
-        if module not in self._activations:
-            self._activations[module] = activations
+        if id(module) not in self._activations:
+            self._activations[id(module)] = activations
         else:
-            self._activations[module] += activations
+            self._activations[id(module)] += activations

     def _estimate_hidden_size_importance(self) -> TracedHp.Importance:
         """Return the activation magnitude-based importance of the hidden_size."""

@@ -1293,7 +1293,7 @@ def _export_drop_layers(self) -> None:
         for layer in self.decoder.layers:
             assert layer._scores > 0, "No scores collected for importance estimation."

-        # gather layer scores from all TP regions
+        # gather layer scores from all PP ranks
         layer_scores = {}
         for layer in self.decoder.layers:
             layer_scores[layer.layer_number] = layer._scores

@@ -1302,7 +1302,7 @@ def _export_drop_layers(self) -> None:
             all_pp_layer_scores, layer_scores, group=get_pipeline_model_parallel_group()
         )
         layer_scores = {k: v for d in all_pp_layer_scores for k, v in d.items()}  # type: ignore[attr-defined]
-        print_rank_0(f"Layerwise scores for depth pruning: {layer_scores}")
+        print_rank_0(f"Layerwise scores (1-indexed) for depth pruning: {layer_scores}")
         assert sorted(layer_scores.keys()) == list(range(1, num_layers_hp.max + 1))  # type: ignore[arg-type]

         # sort layers by scores and drop the lowest ones

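For context on the hunks above: `_emb_layernorm_forward_hook` accumulates per-module activation statistics during the calibration forward loop, and the change keys that accumulator by `id(module)` instead of the module object, presumably so the saved state holds plain integer keys rather than whole module objects. A simplified, self-contained sketch of the accumulation pattern on a toy model (illustrative only, not the modelopt implementation):

```python
import torch
import torch.nn as nn

# Toy model with two LayerNorms to hook; inputs are [seq, batch, hidden] as a stand-in.
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 8), nn.LayerNorm(8))
activations: dict[int, torch.Tensor] = {}

def hook(module, inputs, output):
    output = output.to(torch.float32)  # use full precision to avoid overflow
    act = output.abs().mean(dim=0)     # reduce over the leading (sequence) dimension
    act = act.pow(2).sum(dim=0)        # -> one value per hidden channel
    if id(module) not in activations:  # key by id(module), as in the diff
        activations[id(module)] = act
    else:
        activations[id(module)] += act

handles = [
    m.register_forward_hook(hook) for m in model.modules() if isinstance(m, nn.LayerNorm)
]

with torch.no_grad():
    for _ in range(3):  # stand-in for the calibration forward loop
        model(torch.randn(5, 2, 8))

for h in handles:
    h.remove()

print({k: v.shape for k, v in activations.items()})  # two entries, each of shape [8]
```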
modelopt/torch/opt/searcher.py

Lines changed: 9 additions & 2 deletions
@@ -27,6 +27,7 @@
 from collections.abc import Callable
 from contextlib import nullcontext
 from typing import Any, final
+from warnings import warn

 import numpy as np
 import pulp

@@ -239,7 +240,11 @@ def load_search_checkpoint(self) -> bool:
         """Load function for search checkpoint returning indicator whether checkpoint was loaded."""
         # check if checkpoint exists
         checkpoint: str | None = self.config["checkpoint"]
-        if checkpoint is None or not os.path.exists(checkpoint):
+        if checkpoint is None:
+            return False
+        if not os.path.exists(checkpoint):
+            if dist.is_master():
+                warn(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.")
             return False

         # iterate through state dict and load keys

@@ -250,14 +255,16 @@ def load_search_checkpoint(self) -> bool:
             setattr(self, key, state)
         return True

-    def save_search_checkpoint(self) -> None:
+    def save_search_checkpoint(self, verbose=False) -> None:
         """Save function for search checkpoint."""
         # check if save requirements are satisfied
         checkpoint: str | None = self.config["checkpoint"]
         if checkpoint is None or not dist.is_master():
             return

         # save state dict
+        if verbose:
+            print(f"Saving searcher state to {checkpoint}...")
         save_dirname, _ = os.path.split(checkpoint)
         if save_dirname:
             os.makedirs(save_dirname, exist_ok=True)

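The two hunks above split the former `checkpoint is None or not os.path.exists(checkpoint)` check so that a configured-but-missing path now warns on the master rank before initializing from scratch, and `save_search_checkpoint` gains an optional `verbose` flag. A rough single-process sketch of the resulting save/load behavior with the modelopt and distributed plumbing stripped out (the class name, state keys, and path below are hypothetical):

```python
import os
from warnings import warn

import torch

class TinySearcher:
    """Single-process stand-in for a searcher that checkpoints its state dict."""

    def __init__(self, checkpoint: str | None):
        self.checkpoint = checkpoint
        self.scores: dict = {}
        self.activations: dict = {}

    def load_search_checkpoint(self) -> bool:
        if self.checkpoint is None:
            return False  # nothing configured -> silently skip
        if not os.path.exists(self.checkpoint):
            warn(f"Checkpoint {self.checkpoint} does not exist! Initializing from scratch.")
            return False  # configured but missing -> warn, then start fresh
        for key, state in torch.load(self.checkpoint).items():
            setattr(self, key, state)  # restore each saved attribute
        return True

    def save_search_checkpoint(self, verbose: bool = False) -> None:
        if self.checkpoint is None:
            return
        if verbose:
            print(f"Saving searcher state to {self.checkpoint}...")
        dirname = os.path.dirname(self.checkpoint)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        torch.save({"scores": self.scores, "activations": self.activations}, self.checkpoint)

searcher = TinySearcher("scratch/minitron_scores.pth")
searcher.scores = {"layer_1": torch.tensor(0.7)}
searcher.save_search_checkpoint(verbose=True)
assert TinySearcher("scratch/minitron_scores.pth").load_search_checkpoint()
```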
modelopt/torch/prune/plugins/mcore_minitron.py

Lines changed: 51 additions & 18 deletions
@@ -24,6 +24,8 @@
 Actual dynamic module implementations are at :mod:`modelopt.torch.nas.plugins.megatron`.
 """

+import copy
+
 import torch
 from pydantic import create_model

@@ -39,6 +41,7 @@
 from modelopt.torch.opt.config import ModeloptBaseConfig, get_kwargs_for_create_model_with_rules
 from modelopt.torch.opt.searcher import BaseSearcher, SearchConfig, SearchStateDict
 from modelopt.torch.opt.utils import named_hparams
+from modelopt.torch.utils import distributed as dist
 from modelopt.torch.utils import print_rank_0

 from ..fastnas import FastNASModeDescriptor

@@ -60,23 +63,19 @@
 class MCoreMinitronSearcher(BaseSearcher):
     """Searcher for Minitron pruning algorithm."""

+    activations: dict[int, dict[str, torch.Tensor]]
+    scores: dict[int, dict[str, torch.Tensor]]
+    ckpt_world_size: int
+
     @property
     def default_search_config(self) -> SearchConfig:
         """Get the default config for the searcher."""
         return {**super().default_search_config, "max_iter_data_loader": 1024}

     @property
     def default_state_dict(self) -> SearchStateDict:
-        """Return default state dict."""
-        return {}  # Not used
-
-    def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
-        """Sanitize the search config dict."""
-        config = super().sanitize_search_config(config)
-        assert config["data_loader"] or config["forward_loop"], (
-            "Data loader or forward loop must be provided for importance estimation!"
-        )
-        return config
+        """Return default state dict for importance scores and activations from forward loop."""
+        return {"activations": {}, "scores": {}, "ckpt_world_size": dist.size()}

     def before_search(self) -> None:
         """Optional pre-processing steps before the search."""

@@ -87,10 +86,11 @@ def before_search(self) -> None:
             "Only `export_config` constraint is supported for pruning!"
         )

+        self.constraints["export_config"] = copy.deepcopy(self.constraints["export_config"])
         export_config = self.constraints["export_config"]
         assert isinstance(export_config, dict)  # to keep mypy happy
         assert export_config.keys() <= SUPPORTED_HPARAMS, (
-            f"Only {SUPPORTED_HPARAMS} are supported for pruning!"
+            f"Only {SUPPORTED_HPARAMS} are supported for pruning! Received: {export_config.keys()}"
         )

         assert ("num_attention_heads" in export_config and "num_query_groups" in export_config) or (

@@ -124,14 +124,47 @@ def before_search(self) -> None:
     def run_search(self) -> None:
         """Run actual search."""
         # Run forward loop to collect activations and sort parameters
-        assert self.forward_loop is not None
-        is_training = self.model.training
-        self.model.eval()
-        print_rank_0("Running forward loop...")
-        with torch.no_grad():
-            self.forward_loop(self.model)
+        if self.scores and self.activations:  # Available from checkpoint
+            print_rank_0("Loading activations and scores per rank from checkpoint...")
+            assert self.ckpt_world_size == dist.size(), "World size mismatch!"
+            rank = dist.rank()
+            for n, m in self.model.named_modules():
+                if hasattr(m, "_scores"):
+                    m._scores = self.scores[rank][n]
+                if hasattr(m, "_activations"):
+                    m._activations = self.activations[rank][n]
+        else:
+            print_rank_0("Running forward loop...")
+            assert self.forward_loop is not None
+            is_training = self.model.training
+            self.model.eval()
+            with torch.no_grad():
+                self.forward_loop(self.model)
+            self.model.train(is_training)
+
+            # Store activations and layer scores for re-pruning with different export configs
+            rank = dist.rank()
+            rank_scores = {}
+            rank_activations = {}
+            for n, m in self.model.named_modules():
+                if hasattr(m, "_scores"):
+                    rank_scores[n] = m._scores
+                if hasattr(m, "_activations"):
+                    rank_activations[n] = m._activations
+
+            # Gather scores and activations from all ranks to rank 0
+            all_scores = dist.allgather(rank_scores)
+            all_activations = dist.allgather(rank_activations)
+
+            # Store all ranks' data in the searcher's state
+            for r in range(dist.size()):
+                self.scores[r] = all_scores[r]
+                self.activations[r] = all_activations[r]
+
+            self.save_search_checkpoint(verbose=True)
+            dist.barrier()
+
         sort_parameters(self.model, self.hps_to_sort, verbose=True)
-        self.model.train(is_training)

         # Prune homogeneously
         export_config = self.constraints["export_config"]

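In short, `run_search` now has two paths: if the searcher state already holds per-rank `scores` and `activations` (restored from the checkpoint by `load_search_checkpoint`), they are written back onto the modules by name; otherwise the forward loop runs, each rank harvests its modules' `_scores`/`_activations`, the results are allgathered, and the checkpoint is saved before sorting and pruning. A rough single-process sketch of the collect/restore half (toy module and a fixed `rank`; only the attribute names and the named-module traversal follow the diff):

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Toy module carrying the Minitron-style buffers the searcher looks for."""

    def __init__(self):
        super().__init__()
        self._scores = torch.zeros(1)
        self._activations = torch.zeros(4)

def collect(model: nn.Module) -> tuple[dict, dict]:
    """Harvest per-module scores/activations after the forward loop (one rank's view)."""
    scores, activations = {}, {}
    for name, module in model.named_modules():
        if hasattr(module, "_scores"):
            scores[name] = module._scores
        if hasattr(module, "_activations"):
            activations[name] = module._activations
    return scores, activations

def restore(model: nn.Module, scores: dict, activations: dict) -> None:
    """Push checkpointed scores/activations back onto the modules by name."""
    for name, module in model.named_modules():
        if hasattr(module, "_scores"):
            module._scores = scores[name]
        if hasattr(module, "_activations"):
            module._activations = activations[name]

model = nn.Sequential(Block(), Block())
rank = 0  # single-process stand-in for dist.rank()

scores, activations = collect(model)  # what this rank would contribute to the checkpoint
ckpt = {"scores": {rank: scores}, "activations": {rank: activations}, "ckpt_world_size": 1}

restore(model, ckpt["scores"][rank], ckpt["activations"][rank])  # the re-pruning path
```

The per-rank keying is also why the diff asserts `ckpt_world_size == dist.size()`: a saved checkpoint is only reusable under the same parallel layout.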
tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Lines changed: 42 additions & 22 deletions
@@ -41,6 +41,7 @@ def _test_mcore_gpt_pruning(
     pruned_hidden_size_div,
     pruned_num_layers_div,
     uneven_pp,
+    ckpt_path,
     rank,
     size,
 ):

@@ -66,22 +67,26 @@
     else:
         raise ValueError(f"Unsupported size {size}")

-    model = get_mcore_gpt_model(
-        tensor_model_parallel_size=1,
-        pipeline_model_parallel_size=size,
-        initialize_megatron=True,
-        num_layers=num_layers,
-        hidden_size=hidden_size,
-        num_attention_heads=num_attention_heads,
-        num_query_groups=num_query_groups,
-        ffn_hidden_size=ffn_hidden_size,
-        max_sequence_length=max_sequence_length,
-        vocab_size=vocab_size,
-        activation_func=activation_func,
-        normalization=normalization,
-        num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage,
-        num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage,
-    )
+    def _get_model(initialize_megatron=True):
+        model = get_mcore_gpt_model(
+            tensor_model_parallel_size=1,
+            pipeline_model_parallel_size=size,
+            initialize_megatron=initialize_megatron,
+            num_layers=num_layers,
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            num_query_groups=num_query_groups,
+            ffn_hidden_size=ffn_hidden_size,
+            max_sequence_length=max_sequence_length,
+            vocab_size=vocab_size,
+            activation_func=activation_func,
+            normalization=normalization,
+            num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage,
+            num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage,
+        )
+        return model
+
+    model = _get_model()

     def forward_loop(m):
         for _ in range(5):

@@ -110,7 +115,7 @@ def forward_loop(m):
         mode="mcore_minitron",
         constraints={"export_config": export_config},
         dummy_input=None,  # Not used
-        config={"forward_loop": forward_loop},
+        config={"forward_loop": forward_loop, "checkpoint": ckpt_path},
     )

     # Assert weights are pruned correctly

@@ -139,6 +144,17 @@ def forward_loop(m):
     # Assert forward pass works on the pruned model
     run_mcore_inference_with_dummy_input(model, batch_size, pruned_hidden_size)

+    # Assert re-pruning from checkpoint works without running the forward loop again
+    if ckpt_path:
+        model = _get_model(initialize_megatron=False)
+        mtp.prune(
+            model,
+            mode="mcore_minitron",
+            constraints={"export_config": export_config},
+            dummy_input=None,  # Not used
+            config={"checkpoint": ckpt_path},
+        )
+

 @pytest.mark.parametrize(
     (

@@ -152,16 +168,18 @@ def forward_loop(m):
         "hidden_size_div",
         "num_layers_div",
         "uneven_pp",
+        "test_ckpt",
     ),
     [
-        (8, 8, "squared_relu", "LayerNorm", 4, 1, 1, 1, 1, False),  # MHA - pruned ffn/4
-        (8, 4, "squared_relu", "RMSNorm", 1, 2, 2, 1, 1, False),  # GQA - pruned attention/2
-        (8, 4, "swiglu", "RMSNorm", 1, 1, 1, 4, 1, False),  # GQA - pruned hidden_size/4
-        (8, 8, "swiglu", "LayerNorm", 1, 1, 1, 1, 2, False),  # MHA - pruned num_layers/2
-        (8, 4, "swiglu", "RMSNorm", 2, 2, 2, 2, 2, True),  # GQA - pruned all/2, uneven pp
+        (8, 8, "squared_relu", "LayerNorm", 4, 1, 1, 1, 1, False, False),  # MHA - pruned ffn/4
+        (8, 4, "squared_relu", "RMSNorm", 1, 2, 2, 1, 1, False, False),  # GQA - pruned attention/2
+        (8, 4, "swiglu", "RMSNorm", 1, 1, 1, 4, 1, False, False),  # GQA - pruned hidden_size/4
+        (8, 8, "swiglu", "LayerNorm", 1, 1, 1, 1, 2, False, False),  # MHA - pruned num_layers/2
+        (8, 4, "swiglu", "RMSNorm", 2, 2, 2, 2, 2, True, True),  # GQA - pruned all/2, uneven pp
     ],
 )
 def test_mcore_gpt_pruning(
+    tmp_path,
     num_attention_heads,
     num_query_groups,
     activation_func,

@@ -172,6 +190,7 @@ def test_mcore_gpt_pruning(
     hidden_size_div,
     num_layers_div,
     uneven_pp,
+    test_ckpt,
 ):
     spawn_multiprocess_job(
         size=torch.cuda.device_count(),

@@ -187,6 +206,7 @@ def test_mcore_gpt_pruning(
             hidden_size_div,
             num_layers_div,
             uneven_pp,
+            tmp_path / "modelopt_minitron_scores.pth" if test_ckpt else None,
         ),
         backend="nccl",
     )
