
Commit 651f9d2

Support checkpointing Minitron scores
Signed-off-by: Keval Morabia <[email protected]>
1 parent 115b145 commit 651f9d2

11 files changed: +243 −84 lines

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ Model Optimizer Changelog (Linux)
 - ``high_precision_dtype`` default to fp16 in ONNX quantization, i.e. quantized output model weights are now FP16 by default.
 - Upgrade TensorRT-LLM dependency to 1.1.0rc2.
 - Support Phi-4-multimodal and Qwen2.5-VL quantized HF checkpoint export in ``examples/vlm_ptq``.
+- Support storing and restoring Minitron pruning activations and scores for re-pruning without running the forward loop again.
 - Add Minitron pruning example for Megatron-LM framework. See ``examples/megatron-lm`` for more details.

 0.35 (2025-09-04)

docs/source/guides/3_pruning.rst

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ Pruning

 .. tip::

-    Checkout `Llama 3.1 NeMo Minitron Pruning <https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/llama/pruning-distillation>`_ and
+    Checkout `Qwen 3 NeMo Minitron Pruning & Distillation <https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation>`_ and
     `ResNet20 on CIFAR-10 Notebook <https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/examples/pruning/cifar_resnet.ipynb>`_
     for an end-to-end example of pruning.

examples/llm_distill/README.md

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ Loss balancers:

 Checkout the stand-alone distillation script in the [NVIDIA NeMo repository](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/distillation/distillation.html).

-You can also look at the tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/llama/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Llama 3.1 8B step-by-step in NeMo framework.
+You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Qwen 3 8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.

 ## Knowledge Distillation (KD) for HuggingFace Models

examples/pruning/README.md

Lines changed: 9 additions & 5 deletions
@@ -59,23 +59,27 @@ def forward_loop(model):
     evaluate_and_print_results(model, ...)


-# Specify the pruning constraints
+# Specify the pruning constraints (Check Support Matrix for available pruning dimensions)
 export_config = {
     "hidden_size": 3072,
     "ffn_hidden_size": 9216,
 }


 # Run the pruning process
-mtp.prune(
+# Save minitron scores at scores_path so we can re-run pruning with different export configs without running the forward loop again
+# NOTE: Skip scores_path on re-running if you want to change the dataset and re-calibrate
+model, pruning_scores = mtp.prune(
     model,
     mode="mcore_minitron",
     constraints={"export_config": export_config},
     dummy_input=None,  # Not used
-    config={"forward_loop": forward_loop},
+    config={"forward_loop": forward_loop, "scores_path": "modelopt_minitron_scores.pth"},
 )
 ```

+If your model parameters are already sorted, you can skip the sorting step by setting `"skip_sorting": True` in `config` instead of passing `forward_loop`.
+
 > [!Note]
 > Fine-tuning / distillation is required after pruning to recover the accuracy. Please refer to pruning [fine-tuning](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/3_pruning.html#pruning-fine-tuning) for more details.

@@ -91,11 +95,11 @@ mtp.prune(

 ## Examples

-### Minitron Pruning for Megatron-LM / NeMo Framework LLMs (e.g. Llama 3.1, Nemotron Nano)
+### Minitron Pruning for Megatron-LM / NeMo Framework LLMs (e.g. Qwen 3, Nemotron Nano)

 Checkout the Minitron pruning example for the [Megatron-LM Framework](../megatron-lm/README.md#-pruning) and [NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/pruning/pruning.html) which showcases the usage of the powerful Minitron pruning algorithm developed by NVIDIA Research for pruning LLMs like Llama 3.1 8B, Qwen 3 8B, Nemotron Nano 12B v2, etc.

-You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/llama/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Llama 3.1 8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.
+You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Qwen 3 8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.

 Some of the models pruned using Minitron method followed by distillation and post-training are:
modelopt/torch/nas/plugins/megatron.py

Lines changed: 61 additions & 9 deletions
@@ -649,6 +649,10 @@ def _layer_imp_forward_hook(self, module, args, kwargs, output) -> None:
         if hidden_states.shape[-1] != self.max_hidden_size:
             return

+        # use full precision to avoid overflow
+        hidden_states = hidden_states.to(torch.float32)
+        output = output.to(torch.float32)
+
         with torch.no_grad():
             # Lower cosine_similarity means higher importance hence use 1 - cosine_similarity
             score = 1 - F.cosine_similarity(hidden_states, output, dim=2).mean()

@@ -1234,10 +1238,10 @@ def _emb_layernorm_forward_hook(self, module, input, output) -> None:
         output = output.to(torch.float32)  # use full precision to avoid overflow
         activations = output.abs().mean(dim=0)  # [batch_size, hidden_size]
         activations = activations.pow(2).sum(dim=0)  # [hidden_size]
-        if module not in self._activations:
-            self._activations[module] = activations
+        if id(module) not in self._activations:
+            self._activations[id(module)] = activations
         else:
-            self._activations[module] += activations
+            self._activations[id(module)] += activations

     def _estimate_hidden_size_importance(self) -> TracedHp.Importance:
         """Return the activation magnitude-based importance of the hidden_size."""

@@ -1284,16 +1288,14 @@ def modify(
             mamba_head_dim_divisor=mamba_head_dim_divisor,
         )

-    def _export_drop_layers(self) -> None:
-        """Drop layers during export if num_layers hparam is set to a smaller value during pruning."""
+    def _get_layer_scores(self) -> dict[int, torch.Tensor]:
+        """Get the layer scores (1-indexed) from the module."""
         num_layers_hp = self.get_hparam("num_layers")
-        if num_layers_hp.active == num_layers_hp.max:  # no depth pruning
-            return

         for layer in self.decoder.layers:
             assert layer._scores > 0, "No scores collected for importance estimation."

-        # gather layer scores from all TP regions
+        # gather layer scores from all PP ranks
         layer_scores = {}
         for layer in self.decoder.layers:
             layer_scores[layer.layer_number] = layer._scores

@@ -1302,10 +1304,19 @@ def _export_drop_layers(self) -> None:
             all_pp_layer_scores, layer_scores, group=get_pipeline_model_parallel_group()
         )
         layer_scores = {k: v for d in all_pp_layer_scores for k, v in d.items()}  # type: ignore[attr-defined]
-        print_rank_0(f"Layerwise scores for depth pruning: {layer_scores}")
+        print_rank_0(f"Layerwise scores (1-indexed, higher is better): {layer_scores}")
         assert sorted(layer_scores.keys()) == list(range(1, num_layers_hp.max + 1))  # type: ignore[arg-type]

+        return layer_scores
+
+    def _export_drop_layers(self) -> None:
+        """Drop layers during export if num_layers hparam is set to a smaller value during pruning."""
+        num_layers_hp = self.get_hparam("num_layers")
+        if num_layers_hp.active == num_layers_hp.max:  # no depth pruning
+            return
+
         # sort layers by scores and drop the lowest ones
+        layer_scores = self._get_layer_scores()
         sorted_layers = sorted(layer_scores.items(), key=lambda x: x[1], reverse=True)
         layers_to_drop = [layer for layer, _ in sorted_layers[num_layers_hp.active :]]  # type: ignore[misc]
         drop_mcore_language_model_layers(self, layers_to_drop=layers_to_drop)

@@ -1337,6 +1348,47 @@ def freeze(self) -> None:
         for layer in self.decoder.layers:
             layer.freeze()

+    def get_activations_and_layer_scores(
+        self,
+    ) -> tuple[list[dict[str, torch.Tensor]], dict[int, torch.Tensor]]:
+        """Get the per-rank activations and layer scores from the module."""
+        local_activations = {}
+        for n, m in self.named_modules():
+            if hasattr(m, "_activations"):
+                local_activations[n] = m._activations
+        activations_per_rank = dist.allgather(
+            local_activations, group=get_pipeline_model_parallel_group()
+        )
+        assert len(activations_per_rank) == get_pipeline_model_parallel_world_size()
+
+        layer_scores = self._get_layer_scores()
+
+        return activations_per_rank, layer_scores
+
+    def set_activations_and_layer_scores(
+        self,
+        activations_per_rank: list[dict[str, torch.Tensor]],
+        layer_scores: dict[int, torch.Tensor],
+    ) -> None:
+        """Set the pre-computed layer_scores and per-rank activations instead of running forward.
+
+        Args:
+            layer_scores: Dict from layer_number (1-indexed) to score.
+            activations_per_rank: List of dicts from module name to activations. Should match PP size.
+        """
+        rank = get_pipeline_model_parallel_rank()
+        pp_size = get_pipeline_model_parallel_world_size()
+        assert len(activations_per_rank) == pp_size, (
+            len(activations_per_rank),
+            activations_per_rank,
+            pp_size,
+        )
+        for layer in self.decoder.layers:
+            layer._scores = layer_scores[layer.layer_number]
+        for n, m in self.named_modules():
+            if hasattr(m, "_activations"):
+                m._activations = activations_per_rank[rank][n]
+

 def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[int]) -> None:
     """Remove given layers (1-indexed) of the model (works with TP and/or PP).

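As a side note on the hook change above: the depth importance score is simply one minus the mean cosine similarity between a layer's input and output, now computed in float32 to avoid overflow. A self-contained sketch of that scoring rule on dummy tensors (the [seq, batch, hidden] shapes are assumed for illustration):

```python
import torch
import torch.nn.functional as F

# Dummy activations standing in for a transformer layer's input hidden_states and
# its output; the [seq_len, batch, hidden] layout is an assumption for illustration.
hidden_states = torch.randn(128, 2, 4096, dtype=torch.bfloat16)
output = hidden_states + 0.1 * torch.randn_like(hidden_states)

# Cast to full precision first, as the commit does, to avoid overflow in the reduction.
hidden_states = hidden_states.to(torch.float32)
output = output.to(torch.float32)

with torch.no_grad():
    # Lower cosine similarity between a layer's input and output means the layer
    # changes the representation more, i.e. it is more important -> use 1 - cos_sim.
    score = 1 - F.cosine_similarity(hidden_states, output, dim=2).mean()

print(f"layer importance score: {score.item():.4f}")
```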
modelopt/torch/opt/searcher.py

Lines changed: 9 additions & 2 deletions
@@ -27,6 +27,7 @@
 from collections.abc import Callable
 from contextlib import nullcontext
 from typing import Any, final
+from warnings import warn

 import numpy as np
 import pulp

@@ -239,7 +240,11 @@ def load_search_checkpoint(self) -> bool:
         """Load function for search checkpoint returning indicator whether checkpoint was loaded."""
         # check if checkpoint exists
         checkpoint: str | None = self.config["checkpoint"]
-        if checkpoint is None or not os.path.exists(checkpoint):
+        if checkpoint is None:
+            return False
+        if not os.path.exists(checkpoint):
+            if dist.is_master():
+                warn(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.")
             return False

         # iterate through state dict and load keys

@@ -250,14 +255,16 @@ def load_search_checkpoint(self) -> bool:
             setattr(self, key, state)
         return True

-    def save_search_checkpoint(self) -> None:
+    def save_search_checkpoint(self, verbose=False) -> None:
         """Save function for search checkpoint."""
         # check if save requirements are satisfied
         checkpoint: str | None = self.config["checkpoint"]
         if checkpoint is None or not dist.is_master():
             return

         # save state dict
+        if verbose:
+            print(f"Saving searcher state to {checkpoint}...")
         save_dirname, _ = os.path.split(checkpoint)
         if save_dirname:
             os.makedirs(save_dirname, exist_ok=True)

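The checkpoint handling above is what lets the Minitron searcher persist its scores and restore them on a later run. A simplified, standalone sketch of that guard logic follows; it assumes the state dict is serialized with `torch.save`/`torch.load`, which may differ from the actual `BaseSearcher` internals, and it omits the distributed master-rank check.

```python
import os
from warnings import warn

import torch


def load_state(checkpoint: str | None) -> dict | None:
    """Return the saved searcher state, or None if there is nothing to restore."""
    if checkpoint is None:
        return None  # checkpointing disabled
    if not os.path.exists(checkpoint):
        # New behavior in this commit: warn instead of silently starting from scratch.
        warn(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.")
        return None
    return torch.load(checkpoint)


def save_state(state: dict, checkpoint: str | None, verbose: bool = False) -> None:
    """Persist the searcher state, creating the parent directory if needed."""
    if checkpoint is None:
        return
    if verbose:
        print(f"Saving searcher state to {checkpoint}...")
    save_dirname, _ = os.path.split(checkpoint)
    if save_dirname:
        os.makedirs(save_dirname, exist_ok=True)
    torch.save(state, checkpoint)
```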
modelopt/torch/prune/plugins/mcore_minitron.py

Lines changed: 49 additions & 16 deletions
@@ -24,13 +24,15 @@
 Actual dynamic module implementations are at :mod:`modelopt.torch.nas.plugins.megatron`.
 """

+import copy
+
 import torch
 from pydantic import create_model

 # isort: off
 # import nas plugin to check if it is enabled else raises an Exception
 from modelopt.torch.nas.plugins.megatron import *  # noqa: F403
-from modelopt.torch.nas.plugins.megatron import HAS_MAMBA
+from modelopt.torch.nas.plugins.megatron import HAS_MAMBA, _DynamicMCoreLanguageModel
 # isort: on

 from modelopt.torch.nas.conversion import NASModeRegistry

@@ -60,22 +62,29 @@
 class MCoreMinitronSearcher(BaseSearcher):
     """Searcher for Minitron pruning algorithm."""

+    activations_per_rank: list[dict[str, torch.Tensor]]
+    layer_scores: dict[int, torch.Tensor]
+
     @property
     def default_search_config(self) -> SearchConfig:
         """Get the default config for the searcher."""
-        return {**super().default_search_config, "max_iter_data_loader": 1024}
+        return {
+            **super().default_search_config,
+            "max_iter_data_loader": 1024,
+            "skip_sorting": False,
+            "scores_path": None,
+        }

     @property
     def default_state_dict(self) -> SearchStateDict:
-        """Return default state dict."""
-        return {}  # Not used
+        """Return default state dict for importance scores and activations from forward loop."""
+        return {"activations_per_rank": [], "layer_scores": {}}

     def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
         """Sanitize the search config dict."""
         config = super().sanitize_search_config(config)
-        assert config["data_loader"] or config["forward_loop"], (
-            "Data loader or forward loop must be provided for importance estimation!"
-        )
+        config["checkpoint"] = config["scores_path"]
+        config["verbose"] = True  # Print for all ranks
         return config

     def before_search(self) -> None:

@@ -87,10 +96,11 @@ def before_search(self) -> None:
             "Only `export_config` constraint is supported for pruning!"
         )

+        self.constraints["export_config"] = copy.deepcopy(self.constraints["export_config"])
         export_config = self.constraints["export_config"]
         assert isinstance(export_config, dict)  # to keep mypy happy
         assert export_config.keys() <= SUPPORTED_HPARAMS, (
-            f"Only {SUPPORTED_HPARAMS} are supported for pruning!"
+            f"Only {SUPPORTED_HPARAMS} are supported for pruning! Received: {export_config.keys()}"
         )

         assert ("num_attention_heads" in export_config and "num_query_groups" in export_config) or (

@@ -124,14 +134,37 @@ def before_search(self) -> None:
     def run_search(self) -> None:
         """Run actual search."""
         # Run forward loop to collect activations and sort parameters
-        assert self.forward_loop is not None
-        is_training = self.model.training
-        self.model.eval()
-        print_rank_0("Running forward loop...")
-        with torch.no_grad():
-            self.forward_loop(self.model)
-        sort_parameters(self.model, self.hps_to_sort, verbose=True)
-        self.model.train(is_training)
+        unwrapped_model = self.model
+        for m in self.model.modules():
+            if isinstance(m, _DynamicMCoreLanguageModel):
+                unwrapped_model = m
+                break
+        assert isinstance(unwrapped_model, _DynamicMCoreLanguageModel), "Model not supported!"
+
+        if self.layer_scores and self.activations_per_rank:  # Available from checkpoint
+            print_rank_0("Loading activations and scores per rank from checkpoint...")
+            unwrapped_model.set_activations_and_layer_scores(
+                self.activations_per_rank, self.layer_scores
+            )
+        elif not self.config["skip_sorting"]:
+            print_rank_0("Running forward loop...")
+            assert self.forward_loop is not None
+            is_training = self.model.training
+            self.model.eval()
+            with torch.no_grad():
+                self.forward_loop(self.model)
+            self.model.train(is_training)
+
+            # Store activations and layer scores for re-pruning with different export configs
+            self.activations_per_rank, self.layer_scores = (
+                unwrapped_model.get_activations_and_layer_scores()
+            )
+            self.save_search_checkpoint(verbose=True)
+
+        if self.config["skip_sorting"]:
+            print_rank_0("Skipping sorting parameters...")
+        else:
+            sort_parameters(self.model, self.hps_to_sort, verbose=True)

         # Prune homogeneously
         export_config = self.constraints["export_config"]

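The searcher now also accepts a `skip_sorting` option for models whose parameters are already sorted by importance (for example, a model restored from an earlier pruning run). A hedged sketch of such a call, reusing the width value from the README example; whether sorting can safely be skipped depends on the model actually being pre-sorted:

```python
import modelopt.torch.prune as mtp

# Assumes the model's parameters were already sorted by importance in a previous run,
# so both the calibration forward loop and the sorting step can be skipped.
model, _ = mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"export_config": {"ffn_hidden_size": 9216}},
    dummy_input=None,  # Not used
    config={"skip_sorting": True},
)
```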
modelopt/torch/utils/distributed.py

Lines changed: 9 additions & 9 deletions
@@ -124,23 +124,23 @@ def broadcast(obj: Any, src: int = 0, group=None) -> Any:
     return obj


-def _allgather(tensors: list[torch.Tensor], tensor: torch.Tensor) -> None:
+def _allgather(tensors: list[torch.Tensor], tensor: torch.Tensor, group=None) -> None:
     if backend() == "torch":
-        torch.distributed.all_gather(tensors, tensor)
+        torch.distributed.all_gather(tensors, tensor, group)


-def allgather(obj: Any) -> list[Any]:
+def allgather(obj: Any, group=None) -> list[Any]:
     """Gathers an object from all processes into a list."""
-    if size() == 1:
+    if size(group) == 1:
         return [obj]

     # serialize
     tensor = _serialize(obj).cuda()

     # gather the tensor size
     tensor_size = torch.LongTensor([tensor.numel()]).cuda()
-    tensor_sizes = [torch.LongTensor([0]).cuda() for _ in range(size())]
-    _allgather(tensor_sizes, tensor_size)
+    tensor_sizes = [torch.LongTensor([0]).cuda() for _ in range(size(group))]
+    _allgather(tensor_sizes, tensor_size, group)
     tensor_sizes = [int(tensor_size.item()) for tensor_size in tensor_sizes]
     max_size = max(tensor_sizes)

@@ -149,7 +149,7 @@ def allgather(obj: Any) -> list[Any]:
     if tensor_size != max_size:
         padding = torch.ByteTensor(size=(max_size - tensor_size,)).cuda()
         tensor = torch.cat((tensor, padding), dim=0)
-    _allgather(tensors, tensor)
+    _allgather(tensors, tensor, group)

     # deserialize
     objs = []

@@ -159,9 +159,9 @@ def allgather(obj: Any) -> list[Any]:
     return objs


-def allreduce(obj: Any, reduction: str = "sum") -> Any:
+def allreduce(obj: Any, reduction: str = "sum", group=None) -> Any:
     """Reduces an object from all processes."""
-    objs = allgather(obj)
+    objs = allgather(obj, group)
    if reduction == "sum":
        return sum(objs)
    else:

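The `group` argument threaded through `allgather`/`allreduce` above is what lets the Minitron plugin gather per-rank activations across only the pipeline-parallel group. A standalone sketch of the same pattern using the standard `torch.distributed` object-gather API (not modelopt's internal helpers):

```python
import torch.distributed as torch_dist


def gather_over_group(obj, group=None):
    """Gather `obj` from every rank in `group` into a list ordered by group rank."""
    world = torch_dist.get_world_size(group=group)
    if world == 1:
        return [obj]
    gathered = [None] * world
    # all_gather_object serializes the object under the hood, similar to the
    # serialize-pad-all_gather sequence in modelopt.torch.utils.distributed.allgather.
    torch_dist.all_gather_object(gathered, obj, group=group)
    return gathered
```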