2 changes: 1 addition & 1 deletion .gitlab/tests.yml
@@ -27,7 +27,7 @@ unit:
##### GPU Tests #####
.multi-gpu-tests-default:
extends: .tests-default
timeout: 60m
timeout: 90m
image: nvcr.io/nvidia/pytorch:25.06-py3
variables:
GIT_DEPTH: 1000 # For correct version for tests/gpu/torch/quantization/plugins/test_megatron.py
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -17,6 +17,7 @@ Model Optimizer Changelog (Linux)
- ``high_precision_dtype`` default to fp16 in ONNX quantization, i.e. quantized output model weights are now FP16 by default.
- Upgrade TensorRT-LLM dependency to 1.1.0rc2.
- Support Phi-4-multimodal and Qwen2.5-VL quantized HF checkpoint export in ``examples/vlm_ptq``.
- Support storing and restoring Minitron pruning activations and scores for re-pruning without running the forward loop again.
- Add Minitron pruning example for Megatron-LM framework. See ``examples/megatron-lm`` for more details.

0.35 (2025-09-04)
2 changes: 1 addition & 1 deletion docs/source/guides/3_pruning.rst
@@ -4,7 +4,7 @@ Pruning

.. tip::

Checkout `Llama 3.1 NeMo Minitron Pruning <https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/llama/pruning-distillation>`_ and
Checkout `Qwen 3 NeMo Minitron Pruning & Distillation <https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation>`_ and
`ResNet20 on CIFAR-10 Notebook <https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/examples/pruning/cifar_resnet.ipynb>`_
for an end-to-end example of pruning.

2 changes: 1 addition & 1 deletion examples/llm_distill/README.md
@@ -144,7 +144,7 @@ Loss balancers:

Checkout the stand-alone distillation script in the [NVIDIA NeMo repository](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/distillation/distillation.html).

You can also look at the tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/llama/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Llama 3.1 8B step-by-step in NeMo framework.
You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Qwen 3 8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.

## Knowledge Distillation (KD) for HuggingFace Models

14 changes: 9 additions & 5 deletions examples/pruning/README.md
@@ -59,23 +59,27 @@ def forward_loop(model):
evaluate_and_print_results(model, ...)


# Specify the pruning constraints
# Specify the pruning constraints (Check Support Matrix for available pruning dimensions)
export_config = {
"hidden_size": 3072,
"ffn_hidden_size": 9216,
}


# Run the pruning process
mtp.prune(
# Save minitron scores at scores_path so we can re-run pruning with different export configs without running the forward loop again
# NOTE: Skip scores_path on re-running if you want to change the dataset and re-calibrate
model, pruning_scores = mtp.prune(
model,
mode="mcore_minitron",
constraints={"export_config": export_config},
dummy_input=None, # Not used
config={"forward_loop": forward_loop},
config={"forward_loop": forward_loop, "scores_path": "modelopt_minitron_scores.pth"},
)
```

If your model parameters are already sorted, you can skip the sorting step by setting `"skip_sorting": True` in `config` instead of passing `forward_loop`.
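
For illustration, a second pruning pass might look like the sketch below. It assumes the original (unpruned) model has been reloaded and that `modelopt_minitron_scores.pth` was produced by the `mtp.prune` call above; the target sizes are hypothetical.

```python
# Re-prune to a different width without re-running calibration:
# the stored activations and scores are loaded from scores_path, so no forward_loop is needed.
model, _ = mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"export_config": {"hidden_size": 2048, "ffn_hidden_size": 6144}},
    dummy_input=None,  # Not used
    config={"scores_path": "modelopt_minitron_scores.pth"},
)

# Alternatively, if the model parameters are already sorted:
# config={"skip_sorting": True}
```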

> [!Note]
> Fine-tuning / distillation is required after pruning to recover the accuracy. Please refer to pruning [fine-tuning](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/3_pruning.html#pruning-fine-tuning) for more details.

@@ -91,11 +95,11 @@ mtp.prune(

## Examples

### Minitron Pruning for Megatron-LM / NeMo Framework LLMs (e.g. Llama 3.1, Nemotron Nano)
### Minitron Pruning for Megatron-LM / NeMo Framework LLMs (e.g. Qwen 3, Nemotron Nano)

Checkout the Minitron pruning example for the [Megatron-LM Framework](../megatron-lm/README.md#-pruning) and [NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/pruning/pruning.html) which showcases the usage of the powerful Minitron pruning algorithm developed by NVIDIA Research for pruning LLMs like Llama 3.1 8B, Qwen 3 8B, Nemotron Nano 12B v2, etc.

You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/llama/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Llama 3.1 8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.
You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Qwen 3 8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.

Some of the models pruned using Minitron method followed by distillation and post-training are:

70 changes: 61 additions & 9 deletions modelopt/torch/nas/plugins/megatron.py
@@ -649,6 +649,10 @@ def _layer_imp_forward_hook(self, module, args, kwargs, output) -> None:
if hidden_states.shape[-1] != self.max_hidden_size:
return

# use full precision to avoid overflow
hidden_states = hidden_states.to(torch.float32)
output = output.to(torch.float32)

with torch.no_grad():
# Lower cosine_similarity means higher importance hence use 1 - cosine_similarity
score = 1 - F.cosine_similarity(hidden_states, output, dim=2).mean()
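
For intuition, the score above can be reproduced on toy tensors. This is only an illustrative sketch (shapes follow the usual Megatron `[seq, batch, hidden]` layout); the fp32 cast mirrors the hook and keeps the reduction numerically stable for low-precision activations.

```python
import torch
import torch.nn.functional as F

# Toy stand-ins for a layer's input and output activations: [seq, batch, hidden]
hidden_states = torch.randn(8, 2, 16, dtype=torch.bfloat16)
output = hidden_states + 0.1 * torch.randn_like(hidden_states)

# Use full precision to avoid overflow, as in the hook above
hidden_states = hidden_states.to(torch.float32)
output = output.to(torch.float32)

with torch.no_grad():
    # Lower cosine similarity => the layer changes its input more => higher importance
    score = 1 - F.cosine_similarity(hidden_states, output, dim=2).mean()
print(f"layer importance score: {score.item():.4f}")
```
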
@@ -1234,10 +1238,10 @@ def _emb_layernorm_forward_hook(self, module, input, output) -> None:
output = output.to(torch.float32) # use full precision to avoid overflow
activations = output.abs().mean(dim=0) # [batch_size, hidden_size]
activations = activations.pow(2).sum(dim=0) # [hidden_size]
if module not in self._activations:
self._activations[module] = activations
if id(module) not in self._activations:
self._activations[id(module)] = activations
else:
self._activations[module] += activations
self._activations[id(module)] += activations

def _estimate_hidden_size_importance(self) -> TracedHp.Importance:
"""Return the activation magnitude-based importance of the hidden_size."""
@@ -1284,16 +1288,14 @@ def modify(
mamba_head_dim_divisor=mamba_head_dim_divisor,
)

def _export_drop_layers(self) -> None:
"""Drop layers during export if num_layers hparam is set to a smaller value during pruning."""
def _get_layer_scores(self) -> dict[int, torch.Tensor]:
"""Get the layer scores (1-indexed) from the module."""
num_layers_hp = self.get_hparam("num_layers")
if num_layers_hp.active == num_layers_hp.max: # no depth pruning
return

for layer in self.decoder.layers:
assert layer._scores > 0, "No scores collected for importance estimation."

# gather layer scores from all TP regions
# gather layer scores from all PP ranks
layer_scores = {}
for layer in self.decoder.layers:
layer_scores[layer.layer_number] = layer._scores
@@ -1302,10 +1304,19 @@ def _export_drop_layers(self) -> None:
all_pp_layer_scores, layer_scores, group=get_pipeline_model_parallel_group()
)
layer_scores = {k: v for d in all_pp_layer_scores for k, v in d.items()} # type: ignore[attr-defined]
print_rank_0(f"Layerwise scores for depth pruning: {layer_scores}")
print_rank_0(f"Layerwise scores (1-indexed, higher is better): {layer_scores}")
assert sorted(layer_scores.keys()) == list(range(1, num_layers_hp.max + 1)) # type: ignore[arg-type]

return layer_scores

def _export_drop_layers(self) -> None:
"""Drop layers during export if num_layers hparam is set to a smaller value during pruning."""
num_layers_hp = self.get_hparam("num_layers")
if num_layers_hp.active == num_layers_hp.max: # no depth pruning
return

# sort layers by scores and drop the lowest ones
layer_scores = self._get_layer_scores()
sorted_layers = sorted(layer_scores.items(), key=lambda x: x[1], reverse=True)
layers_to_drop = [layer for layer, _ in sorted_layers[num_layers_hp.active :]] # type: ignore[misc]
drop_mcore_language_model_layers(self, layers_to_drop=layers_to_drop)
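
As a toy walk-through of the selection above (hypothetical scores): layers are ranked by score and everything beyond the target depth is dropped.

```python
import torch

# Hypothetical 1-indexed layer scores gathered across PP ranks (higher is better)
layer_scores = {1: torch.tensor(0.9), 2: torch.tensor(0.2), 3: torch.tensor(0.7), 4: torch.tensor(0.4)}
target_num_layers = 2  # e.g. num_layers_hp.active

sorted_layers = sorted(layer_scores.items(), key=lambda x: x[1], reverse=True)
layers_to_drop = [layer for layer, _ in sorted_layers[target_num_layers:]]
print(layers_to_drop)  # [4, 2]: the two lowest-scoring layers are dropped
```
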
@@ -1337,6 +1348,47 @@ def freeze(self) -> None:
for layer in self.decoder.layers:
layer.freeze()

def get_activations_and_layer_scores(
self,
) -> tuple[list[dict[str, torch.Tensor]], dict[int, torch.Tensor]]:
"""Get the per-rank activations and layer scores from the module."""
local_activations = {}
for n, m in self.named_modules():
if hasattr(m, "_activations"):
local_activations[n] = m._activations
activations_per_rank = dist.allgather(
local_activations, group=get_pipeline_model_parallel_group()
)
assert len(activations_per_rank) == get_pipeline_model_parallel_world_size()

layer_scores = self._get_layer_scores()

return activations_per_rank, layer_scores

def set_activations_and_layer_scores(
self,
activations_per_rank: list[dict[str, torch.Tensor]],
layer_scores: dict[int, torch.Tensor],
) -> None:
"""Set the pre-computed layer_scores and per-rank activations instead of running forward.

Args:
layer_scores: Dict from layer_number (1-indexed) to score.
activations_per_rank: List of dicts from module name to activations. Should match PP size.
"""
rank = get_pipeline_model_parallel_rank()
pp_size = get_pipeline_model_parallel_world_size()
assert len(activations_per_rank) == pp_size, (
len(activations_per_rank),
activations_per_rank,
pp_size,
)
for layer in self.decoder.layers:
layer._scores = layer_scores[layer.layer_number]
for n, m in self.named_modules():
if hasattr(m, "_activations"):
m._activations = activations_per_rank[rank][n]
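
A minimal sketch of persisting and restoring these results by hand, assuming `unwrapped_model` is the `_DynamicMCoreLanguageModel` instance and the file path is arbitrary; in practice, passing `scores_path` to `mtp.prune` performs this bookkeeping automatically.

```python
import torch

# First run: after the forward loop, stash the calibration state
activations_per_rank, layer_scores = unwrapped_model.get_activations_and_layer_scores()
torch.save(
    {"activations_per_rank": activations_per_rank, "layer_scores": layer_scores},
    "minitron_scores.pth",  # assumed path
)

# Later run: restore the state instead of running the forward loop again
state = torch.load("minitron_scores.pth")
unwrapped_model.set_activations_and_layer_scores(
    state["activations_per_rank"], state["layer_scores"]
)
```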


def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[int]) -> None:
"""Remove given layers (1-indexed) of the model (works with TP and/or PP).
11 changes: 9 additions & 2 deletions modelopt/torch/opt/searcher.py
@@ -27,6 +27,7 @@
from collections.abc import Callable
from contextlib import nullcontext
from typing import Any, final
from warnings import warn

import numpy as np
import pulp
@@ -239,7 +240,11 @@ def load_search_checkpoint(self) -> bool:
"""Load function for search checkpoint returning indicator whether checkpoint was loaded."""
# check if checkpoint exists
checkpoint: str | None = self.config["checkpoint"]
if checkpoint is None or not os.path.exists(checkpoint):
if checkpoint is None:
return False
if not os.path.exists(checkpoint):
if dist.is_master():
warn(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.")
return False

# iterate through state dict and load keys
@@ -250,14 +255,16 @@
setattr(self, key, state)
return True

def save_search_checkpoint(self) -> None:
def save_search_checkpoint(self, verbose=False) -> None:
"""Save function for search checkpoint."""
# check if save requirements are satisfied
checkpoint: str | None = self.config["checkpoint"]
if checkpoint is None or not dist.is_master():
return

# save state dict
if verbose:
print(f"Saving searcher state to {checkpoint}...")
save_dirname, _ = os.path.split(checkpoint)
if save_dirname:
os.makedirs(save_dirname, exist_ok=True)
65 changes: 49 additions & 16 deletions modelopt/torch/prune/plugins/mcore_minitron.py
@@ -24,13 +24,15 @@
Actual dynamic module implementations are at :mod:`modelopt.torch.nas.plugins.megatron`.
"""

import copy

import torch
from pydantic import create_model

# isort: off
# import nas plugin to check if it is enabled else raises an Exception
from modelopt.torch.nas.plugins.megatron import * # noqa: F403
from modelopt.torch.nas.plugins.megatron import HAS_MAMBA
from modelopt.torch.nas.plugins.megatron import HAS_MAMBA, _DynamicMCoreLanguageModel
# isort: on

from modelopt.torch.nas.conversion import NASModeRegistry
@@ -60,22 +62,29 @@
class MCoreMinitronSearcher(BaseSearcher):
"""Searcher for Minitron pruning algorithm."""

activations_per_rank: list[dict[str, torch.Tensor]]
layer_scores: dict[int, torch.Tensor]

@property
def default_search_config(self) -> SearchConfig:
"""Get the default config for the searcher."""
return {**super().default_search_config, "max_iter_data_loader": 1024}
return {
**super().default_search_config,
"max_iter_data_loader": 1024,
"skip_sorting": False,
"scores_path": None,
}

@property
def default_state_dict(self) -> SearchStateDict:
"""Return default state dict."""
return {} # Not used
"""Return default state dict for importance scores and activations from forward loop."""
return {"activations_per_rank": [], "layer_scores": {}}

def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
"""Sanitize the search config dict."""
config = super().sanitize_search_config(config)
assert config["data_loader"] or config["forward_loop"], (
"Data loader or forward loop must be provided for importance estimation!"
)
config["checkpoint"] = config["scores_path"]
config["verbose"] = True # Print for all ranks
return config

def before_search(self) -> None:
Expand All @@ -87,10 +96,11 @@ def before_search(self) -> None:
"Only `export_config` constraint is supported for pruning!"
)

self.constraints["export_config"] = copy.deepcopy(self.constraints["export_config"])
export_config = self.constraints["export_config"]
assert isinstance(export_config, dict) # to keep mypy happy
assert export_config.keys() <= SUPPORTED_HPARAMS, (
f"Only {SUPPORTED_HPARAMS} are supported for pruning!"
f"Only {SUPPORTED_HPARAMS} are supported for pruning! Received: {export_config.keys()}"
)

assert ("num_attention_heads" in export_config and "num_query_groups" in export_config) or (
@@ -124,14 +134,37 @@ def before_search(self) -> None:
def run_search(self) -> None:
"""Run actual search."""
# Run forward loop to collect activations and sort parameters
assert self.forward_loop is not None
is_training = self.model.training
self.model.eval()
print_rank_0("Running forward loop...")
with torch.no_grad():
self.forward_loop(self.model)
sort_parameters(self.model, self.hps_to_sort, verbose=True)
self.model.train(is_training)
unwrapped_model = self.model
for m in self.model.modules():
if isinstance(m, _DynamicMCoreLanguageModel):
unwrapped_model = m
break
assert isinstance(unwrapped_model, _DynamicMCoreLanguageModel), "Model not supported!"

if self.layer_scores and self.activations_per_rank: # Available from checkpoint
print_rank_0("Loading activations and scores per rank from checkpoint...")
unwrapped_model.set_activations_and_layer_scores(
self.activations_per_rank, self.layer_scores
)
elif not self.config["skip_sorting"]:
print_rank_0("Running forward loop...")
assert self.forward_loop is not None
is_training = self.model.training
self.model.eval()
with torch.no_grad():
self.forward_loop(self.model)
self.model.train(is_training)

# Store activations and layer scores for re-pruning with different export configs
self.activations_per_rank, self.layer_scores = (
unwrapped_model.get_activations_and_layer_scores()
)
self.save_search_checkpoint(verbose=True)

if self.config["skip_sorting"]:
print_rank_0("Skipping sorting parameters...")
else:
sort_parameters(self.model, self.hps_to_sort, verbose=True)

# Prune homogeneously
export_config = self.constraints["export_config"]
18 changes: 9 additions & 9 deletions modelopt/torch/utils/distributed.py
@@ -124,23 +124,23 @@ def broadcast(obj: Any, src: int = 0, group=None) -> Any:
return obj


def _allgather(tensors: list[torch.Tensor], tensor: torch.Tensor) -> None:
def _allgather(tensors: list[torch.Tensor], tensor: torch.Tensor, group=None) -> None:
if backend() == "torch":
torch.distributed.all_gather(tensors, tensor)
torch.distributed.all_gather(tensors, tensor, group)


def allgather(obj: Any) -> list[Any]:
def allgather(obj: Any, group=None) -> list[Any]:
"""Gathers an object from all processes into a list."""
if size() == 1:
if size(group) == 1:
return [obj]

# serialize
tensor = _serialize(obj).cuda()

# gather the tensor size
tensor_size = torch.LongTensor([tensor.numel()]).cuda()
tensor_sizes = [torch.LongTensor([0]).cuda() for _ in range(size())]
_allgather(tensor_sizes, tensor_size)
tensor_sizes = [torch.LongTensor([0]).cuda() for _ in range(size(group))]
_allgather(tensor_sizes, tensor_size, group)
tensor_sizes = [int(tensor_size.item()) for tensor_size in tensor_sizes]
max_size = max(tensor_sizes)

@@ -149,7 +149,7 @@ def allgather(obj: Any) -> list[Any]:
if tensor_size != max_size:
padding = torch.ByteTensor(size=(max_size - tensor_size,)).cuda()
tensor = torch.cat((tensor, padding), dim=0)
_allgather(tensors, tensor)
_allgather(tensors, tensor, group)

# deserialize
objs = []
Expand All @@ -159,9 +159,9 @@ def allgather(obj: Any) -> list[Any]:
return objs


def allreduce(obj: Any, reduction: str = "sum") -> Any:
def allreduce(obj: Any, reduction: str = "sum", group=None) -> Any:
"""Reduces an object from all processes."""
objs = allgather(obj)
objs = allgather(obj, group)
if reduction == "sum":
return sum(objs)
else:
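
A brief usage sketch of the now group-aware helpers. Here `pp_group` is a placeholder (for example, Megatron-Core's `get_pipeline_model_parallel_group()`); with `group=None` the calls operate over the default world group, and `torch.distributed` must already be initialized for multi-rank use.

```python
import torch
from modelopt.torch.utils import distributed as dist

pp_group = None  # placeholder; e.g. get_pipeline_model_parallel_group() under Megatron-Core

# Gather an arbitrary picklable object from every rank in the group
local_stats = {"mean_activation": torch.tensor(0.42)}
per_rank_stats = dist.allgather(local_stats, group=pp_group)  # one entry per rank

# Reduce a value across the group (sum by default)
total_tokens = dist.allreduce(1024, reduction="sum", group=pp_group)
```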