Commit 136e94d

Support checkpointing Minitron scores
Signed-off-by: Keval Morabia <[email protected]>
1 parent 115b145 commit 136e94d

11 files changed: +236 -82 lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ Model Optimizer Changelog (Linux)
 - ``high_precision_dtype`` default to fp16 in ONNX quantization, i.e. quantized output model weights are now FP16 by default.
 - Upgrade TensorRT-LLM dependency to 1.1.0rc2.
 - Support Phi-4-multimodal and Qwen2.5-VL quantized HF checkpoint export in ``examples/vlm_ptq``.
+- Support storing and restoring Minitron pruning activations and scores for re-pruning without running the forward loop again.
 - Add Minitron pruning example for Megatron-LM framework. See ``examples/megatron-lm`` for more details.

 0.35 (2025-09-04)

docs/source/guides/3_pruning.rst

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ Pruning

 .. tip::

-    Checkout `Llama 3.1 NeMo Minitron Pruning <https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/llama/pruning-distillation>`_ and
+    Checkout `Qwen 3 NeMo Minitron Pruning & Distillation <https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation>`_ and
     `ResNet20 on CIFAR-10 Notebook <https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/examples/pruning/cifar_resnet.ipynb>`_
     for an end-to-end example of pruning.

examples/llm_distill/README.md

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ Loss balancers:

 Checkout the stand-alone distillation script in the [NVIDIA NeMo repository](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/distillation/distillation.html).

-You can also look at the tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/llama/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Llama 3.1 8B step-by-step in NeMo framework.
+You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Qwen 3 8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.

 ## Knowledge Distillation (KD) for HuggingFace Models

examples/pruning/README.md

Lines changed: 9 additions & 5 deletions
@@ -59,23 +59,27 @@ def forward_loop(model):
     evaluate_and_print_results(model, ...)


-# Specify the pruning constraints
+# Specify the pruning constraints (Check Support Matrix for available pruning dimensions)
 export_config = {
     "hidden_size": 3072,
     "ffn_hidden_size": 9216,
 }


 # Run the pruning process
-mtp.prune(
+# Save minitron scores at scores_path so we can re-run pruning with different export configs without running the forward loop again
+# NOTE: Skip scores_path on re-running if you want to change the dataset and re-calibrate
+model, pruning_scores = mtp.prune(
     model,
     mode="mcore_minitron",
     constraints={"export_config": export_config},
     dummy_input=None,  # Not used
-    config={"forward_loop": forward_loop},
+    config={"forward_loop": forward_loop, "scores_path": "modelopt_minitron_scores.pth"},
 )
 ```

+If your model parameters are already sorted, you can skip the sorting step by setting `"skip_sorting": True` in `config` instead of passing `forward_loop`.
+
 > [!Note]
 > Fine-tuning / distillation is required after pruning to recover the accuracy. Please refer to pruning [fine-tuning](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/3_pruning.html#pruning-fine-tuning) for more details.

@@ -91,11 +95,11 @@ mtp.prune(

 ## Examples

-### Minitron Pruning for Megatron-LM / NeMo Framework LLMs (e.g. Llama 3.1, Nemotron Nano)
+### Minitron Pruning for Megatron-LM / NeMo Framework LLMs (e.g. Qwen 3, Nemotron Nano)

 Checkout the Minitron pruning example for the [Megatron-LM Framework](../megatron-lm/README.md#-pruning) and [NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/pruning/pruning.html) which showcases the usage of the powerful Minitron pruning algorithm developed by NVIDIA Research for pruning LLMs like Llama 3.1 8B, Qwen 3 8B, Nemotron Nano 12B v2, etc.

-You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/llama/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Llama 3.1 8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.
+You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Qwen 3 8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.

 Some of the models pruned using Minitron method followed by distillation and post-training are:

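The `scores_path` addition above is what makes re-pruning cheap: a second call can reuse the cached activations and layer scores instead of re-running calibration. Below is a minimal sketch of such a re-run, continuing the README example above and assuming the first call already wrote `modelopt_minitron_scores.pth`; the alternative `export_config` values are illustrative, not part of this commit.

```python
import modelopt.torch.prune as mtp

# Re-prune a fresh copy of the same model to a different target width; the forward
# loop is skipped because activations and scores are restored from scores_path
# (assumption: the file from the first run exists and the parallelism layout is unchanged).
new_export_config = {"hidden_size": 2048, "ffn_hidden_size": 6144}  # illustrative values

model, pruning_scores = mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"export_config": new_export_config},
    dummy_input=None,  # Not used
    config={"scores_path": "modelopt_minitron_scores.pth"},
)
```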
modelopt/torch/nas/plugins/megatron.py

Lines changed: 57 additions & 9 deletions
@@ -1234,10 +1234,10 @@ def _emb_layernorm_forward_hook(self, module, input, output) -> None:
         output = output.to(torch.float32)  # use full precision to avoid overflow
         activations = output.abs().mean(dim=0)  # [batch_size, hidden_size]
         activations = activations.pow(2).sum(dim=0)  # [hidden_size]
-        if module not in self._activations:
-            self._activations[module] = activations
+        if id(module) not in self._activations:
+            self._activations[id(module)] = activations
         else:
-            self._activations[module] += activations
+            self._activations[id(module)] += activations

     def _estimate_hidden_size_importance(self) -> TracedHp.Importance:
         """Return the activation magnitude-based importance of the hidden_size."""
@@ -1284,16 +1284,14 @@ def modify(
             mamba_head_dim_divisor=mamba_head_dim_divisor,
         )

-    def _export_drop_layers(self) -> None:
-        """Drop layers during export if num_layers hparam is set to a smaller value during pruning."""
+    def _get_layer_scores(self) -> dict[int, torch.Tensor]:
+        """Get the layer scores (1-indexed) from the module."""
         num_layers_hp = self.get_hparam("num_layers")
-        if num_layers_hp.active == num_layers_hp.max:  # no depth pruning
-            return

         for layer in self.decoder.layers:
             assert layer._scores > 0, "No scores collected for importance estimation."

-        # gather layer scores from all TP regions
+        # gather layer scores from all PP ranks
         layer_scores = {}
         for layer in self.decoder.layers:
             layer_scores[layer.layer_number] = layer._scores
@@ -1302,10 +1300,19 @@ def _export_drop_layers(self) -> None:
             all_pp_layer_scores, layer_scores, group=get_pipeline_model_parallel_group()
         )
         layer_scores = {k: v for d in all_pp_layer_scores for k, v in d.items()}  # type: ignore[attr-defined]
-        print_rank_0(f"Layerwise scores for depth pruning: {layer_scores}")
+        print_rank_0(f"Layerwise scores (1-indexed, higher is better): {layer_scores}")
         assert sorted(layer_scores.keys()) == list(range(1, num_layers_hp.max + 1))  # type: ignore[arg-type]

+        return layer_scores
+
+    def _export_drop_layers(self) -> None:
+        """Drop layers during export if num_layers hparam is set to a smaller value during pruning."""
+        num_layers_hp = self.get_hparam("num_layers")
+        if num_layers_hp.active == num_layers_hp.max:  # no depth pruning
+            return
+
         # sort layers by scores and drop the lowest ones
+        layer_scores = self._get_layer_scores()
         sorted_layers = sorted(layer_scores.items(), key=lambda x: x[1], reverse=True)
         layers_to_drop = [layer for layer, _ in sorted_layers[num_layers_hp.active :]]  # type: ignore[misc]
         drop_mcore_language_model_layers(self, layers_to_drop=layers_to_drop)
@@ -1337,6 +1344,47 @@ def freeze(self) -> None:
         for layer in self.decoder.layers:
             layer.freeze()

+    def get_activations_and_layer_scores(
+        self,
+    ) -> tuple[list[dict[str, torch.Tensor]], dict[int, torch.Tensor]]:
+        """Get the per-rank activations and layer scores from the module."""
+        local_activations = {}
+        for n, m in self.named_modules():
+            if hasattr(m, "_activations"):
+                local_activations[n] = m._activations
+        activations_per_rank = dist.allgather(
+            local_activations, group=get_pipeline_model_parallel_group()
+        )
+        assert len(activations_per_rank) == get_pipeline_model_parallel_world_size()
+
+        layer_scores = self._get_layer_scores()
+
+        return activations_per_rank, layer_scores
+
+    def set_activations_and_layer_scores(
+        self,
+        activations_per_rank: list[dict[str, torch.Tensor]],
+        layer_scores: dict[int, torch.Tensor],
+    ) -> None:
+        """Set the pre-computed layer_scores and per-rank activations instead of running forward.
+
+        Args:
+            layer_scores: Dict from layer_number (1-indexed) to score.
+            activations_per_rank: List of dicts from module name to activations. Should match PP size.
+        """
+        rank = get_pipeline_model_parallel_rank()
+        pp_size = get_pipeline_model_parallel_world_size()
+        assert len(activations_per_rank) == pp_size, (
+            len(activations_per_rank),
+            activations_per_rank,
+            pp_size,
+        )
+        for layer in self.decoder.layers:
+            layer._scores = layer_scores[layer.layer_number]
+        for n, m in self.named_modules():
+            if hasattr(m, "_activations"):
+                m._activations = activations_per_rank[rank][n]
+

 def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[int]) -> None:
     """Remove given layers (1-indexed) of the model (works with TP and/or PP).

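These two methods are what the searcher below builds on, and they can also be driven directly. A minimal sketch of a save/restore round trip follows, assuming `dynamic_model` is the converted Megatron-Core language model exposing the methods above; the variable name and file path are illustrative, not part of the commit.

```python
import torch

from modelopt.torch.utils import distributed as dist

# Sketch: after the calibration forward loop has run on the converted (dynamic)
# Megatron-Core model, cache its statistics to disk (file name is illustrative).
activations_per_rank, layer_scores = dynamic_model.get_activations_and_layer_scores()
if dist.is_master():
    torch.save(
        {"activations_per_rank": activations_per_rank, "layer_scores": layer_scores},
        "minitron_stats.pth",
    )

# Later, on the same model and parallel layout, restore them instead of re-running the loop.
state = torch.load("minitron_stats.pth", weights_only=False)
dynamic_model.set_activations_and_layer_scores(
    state["activations_per_rank"], state["layer_scores"]
)
```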
modelopt/torch/opt/searcher.py

Lines changed: 9 additions & 2 deletions
@@ -27,6 +27,7 @@
 from collections.abc import Callable
 from contextlib import nullcontext
 from typing import Any, final
+from warnings import warn

 import numpy as np
 import pulp
@@ -239,7 +240,11 @@ def load_search_checkpoint(self) -> bool:
         """Load function for search checkpoint returning indicator whether checkpoint was loaded."""
         # check if checkpoint exists
         checkpoint: str | None = self.config["checkpoint"]
-        if checkpoint is None or not os.path.exists(checkpoint):
+        if checkpoint is None:
+            return False
+        if not os.path.exists(checkpoint):
+            if dist.is_master():
+                warn(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.")
             return False

         # iterate through state dict and load keys
@@ -250,14 +255,16 @@ def load_search_checkpoint(self) -> bool:
             setattr(self, key, state)
         return True

-    def save_search_checkpoint(self) -> None:
+    def save_search_checkpoint(self, verbose=False) -> None:
         """Save function for search checkpoint."""
         # check if save requirements are satisfied
         checkpoint: str | None = self.config["checkpoint"]
         if checkpoint is None or not dist.is_master():
             return

         # save state dict
+        if verbose:
+            print(f"Saving searcher state to {checkpoint}...")
         save_dirname, _ = os.path.split(checkpoint)
         if save_dirname:
             os.makedirs(save_dirname, exist_ok=True)

modelopt/torch/prune/plugins/mcore_minitron.py

Lines changed: 48 additions & 15 deletions
@@ -24,6 +24,8 @@
 Actual dynamic module implementations are at :mod:`modelopt.torch.nas.plugins.megatron`.
 """

+import copy
+
 import torch
 from pydantic import create_model

@@ -37,6 +39,7 @@
 from modelopt.torch.nas.registry import DMRegistry
 from modelopt.torch.nas.utils import sort_parameters
 from modelopt.torch.opt.config import ModeloptBaseConfig, get_kwargs_for_create_model_with_rules
+from modelopt.torch.opt.dynamic import DynamicModule
 from modelopt.torch.opt.searcher import BaseSearcher, SearchConfig, SearchStateDict
 from modelopt.torch.opt.utils import named_hparams
 from modelopt.torch.utils import print_rank_0
@@ -60,22 +63,28 @@
 class MCoreMinitronSearcher(BaseSearcher):
     """Searcher for Minitron pruning algorithm."""

+    activations_per_rank: list[dict[str, torch.Tensor]]
+    layer_scores: dict[int, torch.Tensor]
+
     @property
     def default_search_config(self) -> SearchConfig:
         """Get the default config for the searcher."""
-        return {**super().default_search_config, "max_iter_data_loader": 1024}
+        return {
+            **super().default_search_config,
+            "max_iter_data_loader": 1024,
+            "skip_sorting": False,
+            "scores_path": None,
+        }

     @property
     def default_state_dict(self) -> SearchStateDict:
-        """Return default state dict."""
-        return {}  # Not used
+        """Return default state dict for importance scores and activations from forward loop."""
+        return {"activations_per_rank": [], "layer_scores": {}}

     def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
         """Sanitize the search config dict."""
         config = super().sanitize_search_config(config)
-        assert config["data_loader"] or config["forward_loop"], (
-            "Data loader or forward loop must be provided for importance estimation!"
-        )
+        config["checkpoint"] = config["scores_path"]
         return config

     def before_search(self) -> None:
@@ -87,10 +96,11 @@ def before_search(self) -> None:
             "Only `export_config` constraint is supported for pruning!"
         )

+        self.constraints["export_config"] = copy.deepcopy(self.constraints["export_config"])
         export_config = self.constraints["export_config"]
         assert isinstance(export_config, dict)  # to keep mypy happy
         assert export_config.keys() <= SUPPORTED_HPARAMS, (
-            f"Only {SUPPORTED_HPARAMS} are supported for pruning!"
+            f"Only {SUPPORTED_HPARAMS} are supported for pruning! Received: {export_config.keys()}"
         )

         assert ("num_attention_heads" in export_config and "num_query_groups" in export_config) or (
@@ -124,14 +134,37 @@ def before_search(self) -> None:
     def run_search(self) -> None:
         """Run actual search."""
         # Run forward loop to collect activations and sort parameters
-        assert self.forward_loop is not None
-        is_training = self.model.training
-        self.model.eval()
-        print_rank_0("Running forward loop...")
-        with torch.no_grad():
-            self.forward_loop(self.model)
-        sort_parameters(self.model, self.hps_to_sort, verbose=True)
-        self.model.train(is_training)
+        unwrapped_model = self.model
+        while hasattr(unwrapped_model, "module"):
+            unwrapped_model = unwrapped_model.module
+        assert isinstance(unwrapped_model, DynamicModule) and hasattr(
+            unwrapped_model, "set_activations_and_layer_scores"
+        ), unwrapped_model
+
+        if self.layer_scores and self.activations_per_rank:  # Available from checkpoint
+            print_rank_0("Loading activations and scores per rank from checkpoint...")
+            unwrapped_model.set_activations_and_layer_scores(
+                self.activations_per_rank, self.layer_scores
+            )
+        elif not self.config["skip_sorting"]:
+            print_rank_0("Running forward loop...")
+            assert self.forward_loop is not None
+            is_training = self.model.training
+            self.model.eval()
+            with torch.no_grad():
+                self.forward_loop(self.model)
+            self.model.train(is_training)
+
+            # Store activations and layer scores for re-pruning with different export configs
+            self.activations_per_rank, self.layer_scores = (
+                unwrapped_model.get_activations_and_layer_scores()
+            )
+            self.save_search_checkpoint(verbose=True)
+
+        if self.config["skip_sorting"]:
+            print_rank_0("Skipping sorting parameters...")
+        else:
+            sort_parameters(self.model, self.hps_to_sort, verbose=True)

         # Prune homogeneously
         export_config = self.constraints["export_config"]

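Because `scores_path` is wired into the generic `checkpoint` slot of `BaseSearcher`, the saved file is simply the searcher's state dict (the `activations_per_rank` and `layer_scores` entries above). Here is a hedged sketch for inspecting it offline, under the assumption that the checkpoint is a regular `torch.save` payload; the file name is illustrative.

```python
import torch

# Load the scores checkpoint written by a previous pruning run (path is illustrative).
state = torch.load("modelopt_minitron_scores.pth", map_location="cpu", weights_only=False)

# Assumption: keys mirror MCoreMinitronSearcher.default_state_dict.
print(sorted(state.keys()))  # expected: ['activations_per_rank', 'layer_scores']

# layer_scores maps 1-indexed layer numbers to importance scores (higher is better),
# which is what drives depth pruning when num_layers is reduced.
for layer_number, score in sorted(state.get("layer_scores", {}).items()):
    print(f"layer {layer_number}: {float(score):.4f}")
```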
modelopt/torch/utils/distributed.py

Lines changed: 9 additions & 9 deletions
@@ -124,23 +124,23 @@ def broadcast(obj: Any, src: int = 0, group=None) -> Any:
     return obj


-def _allgather(tensors: list[torch.Tensor], tensor: torch.Tensor) -> None:
+def _allgather(tensors: list[torch.Tensor], tensor: torch.Tensor, group=None) -> None:
     if backend() == "torch":
-        torch.distributed.all_gather(tensors, tensor)
+        torch.distributed.all_gather(tensors, tensor, group)


-def allgather(obj: Any) -> list[Any]:
+def allgather(obj: Any, group=None) -> list[Any]:
     """Gathers an object from all processes into a list."""
-    if size() == 1:
+    if size(group) == 1:
         return [obj]

     # serialize
     tensor = _serialize(obj).cuda()

     # gather the tensor size
     tensor_size = torch.LongTensor([tensor.numel()]).cuda()
-    tensor_sizes = [torch.LongTensor([0]).cuda() for _ in range(size())]
-    _allgather(tensor_sizes, tensor_size)
+    tensor_sizes = [torch.LongTensor([0]).cuda() for _ in range(size(group))]
+    _allgather(tensor_sizes, tensor_size, group)
     tensor_sizes = [int(tensor_size.item()) for tensor_size in tensor_sizes]
     max_size = max(tensor_sizes)

@@ -149,7 +149,7 @@ def allgather(obj: Any) -> list[Any]:
     if tensor_size != max_size:
         padding = torch.ByteTensor(size=(max_size - tensor_size,)).cuda()
         tensor = torch.cat((tensor, padding), dim=0)
-    _allgather(tensors, tensor)
+    _allgather(tensors, tensor, group)

     # deserialize
     objs = []
@@ -159,9 +159,9 @@ def allgather(obj: Any) -> list[Any]:
     return objs


-def allreduce(obj: Any, reduction: str = "sum") -> Any:
+def allreduce(obj: Any, reduction: str = "sum", group=None) -> Any:
     """Reduces an object from all processes."""
-    objs = allgather(obj)
+    objs = allgather(obj, group)
     if reduction == "sum":
         return sum(objs)
     else:
