Merged
examples/pruning/README.md (8 changes: 4 additions & 4 deletions)
@@ -55,8 +55,8 @@ model = GPTModel(
 # For Megatron-LM framework, you can use the following utility function
 from megatron.training.training import evaluate_and_print_results

-def forward_loop(model):
-    evaluate_and_print_results(model, ...)
+def forward_loop(_):
+    evaluate_and_print_results(prefix, forward_step, train_iterator, model, ...)


 # Specify the pruning constraints (Check Support Matrix for available pruning dimensions)
@@ -66,7 +66,7 @@ export_config = {
 }


-# Run the pruning process
+# Run the pruning process (if model is a list then pass model[0] to the prune API)
 # Save minitron scores at scores_path so we can re-run pruning with different export configs without running the forward loop again
 # NOTE: Skip scores_path on re-running if you want to change the dataset and re-calibrate
 model, pruning_scores = mtp.prune(
@@ -81,7 +81,7 @@ model, pruning_scores = mtp.prune(
 If your model parameters are already sorted, you can skip the sorting step by setting `"skip_sorting": True` in `config` instead of passing `forward_loop`.

 > [!Note]
-> Fine-tuning / distillation is required after pruning to recover the accuracy. Please refer to pruning [fine-tuning](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/3_pruning.html#pruning-fine-tuning) for more details.
+> Fine-tuning / distillation is required after pruning to recover the accuracy. Please refer to the [end-to-end pruning and distillation tutorial](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) for more details.

 ## Support Matrix
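For context, the snippet below sketches how the updated README pieces fit together. This is a hedged sketch, not part of the diff: `prefix`, `forward_step`, `train_iterator`, and `model` are assumed to come from the surrounding Megatron-LM setup, the trailing arguments stay elided as `...` exactly as the README writes them, and the `mcore_minitron` mode name and `hidden_size` value are assumptions for illustration.

```python
# Hedged sketch of the updated README flow (assumes a live Megatron-LM
# training setup providing prefix, forward_step, train_iterator, and model).
import modelopt.torch.prune as mtp
from megatron.training.training import evaluate_and_print_results

def forward_loop(_):
    # The argument is ignored; the top-level `model` is captured from the
    # enclosing scope, matching the new `forward_loop(_)` signature.
    # The trailing `...` mirrors the README's elision of remaining arguments.
    evaluate_and_print_results(prefix, forward_step, train_iterator, model, ...)

export_config = {"hidden_size": 3072}  # illustrative constraint value

# Per the new comment: if `model` is a list, pass model[0] to the prune API.
model, pruning_scores = mtp.prune(
    model[0] if isinstance(model, list) else model,
    mode="mcore_minitron",  # assumed mode name; see the Support Matrix
    constraints={"export_config": export_config},
    dummy_input=None,  # unused when a forward_loop is supplied
    config={"forward_loop": forward_loop, "scores_path": "minitron_scores.pth"},
)
```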
modelopt/torch/nas/conversion.py (6 changes: 4 additions & 2 deletions)
@@ -26,7 +26,9 @@
 NASModeRegistry = _ModeRegistryCls("nas")


-def convert(model: ModelLike, mode: ModeLike) -> nn.Module:
+def convert(
+    model: ModelLike, mode: ModeLike, registry: _ModeRegistryCls = NASModeRegistry
+) -> nn.Module:
     """Convert a regular PyTorch model into a model that supports design space optimization.

     Args:
@@ -84,7 +86,7 @@ def convert(model: ModelLike, mode: ModeLike) -> nn.Module:
     #. Use ``*`` as a wildcard matching any layer.
     """
     # apply mode and handle model-like object with wrapper
-    return apply_mode(model, mode, registry=NASModeRegistry)
+    return apply_mode(model, mode, registry=registry)


 def export(model: nn.Module, strict: bool = True, calib: bool = False) -> nn.Module:
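To make the new parameter concrete, here is a minimal hedged sketch of calling `convert` with a caller-supplied registry. The `fastnas` mode name and the `PruneModeRegistry` import path are assumptions for illustration; they are not asserted by this diff.

```python
# Minimal sketch of the extended convert() signature.
from torch import nn

import modelopt.torch.nas as mtn
from modelopt.torch.prune.pruning import PruneModeRegistry  # assumed import path

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))

# Default behavior is unchanged: modes resolve against NASModeRegistry, e.g.
#   model = mtn.convert(model, mode="autonas")  # "autonas" assumed for illustration

# New behavior: a different registry is forwarded straight to apply_mode(),
# which is exactly how prune() reuses this entry point (see pruning.py below).
model = mtn.convert(model, mode="fastnas", registry=PruneModeRegistry)
```

Because the new argument defaults to `NASModeRegistry`, every existing call site keeps working unchanged.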
modelopt/torch/prune/pruning.py (5 changes: 2 additions & 3 deletions)
@@ -20,7 +20,6 @@
 from torch import nn

 import modelopt.torch.nas as mtn
-from modelopt.torch.opt.conversion import apply_mode
 from modelopt.torch.opt.mode import ModeLike, _ModeRegistryCls
 from modelopt.torch.opt.searcher import ConstraintsDict, SearchConfig
@@ -199,8 +198,8 @@ def prune(
     search algorithm. The returned subnet is thus a reference to the same model instance as the
     input model.
     """
-    # apply prune mode(s) to model
-    model = apply_mode(model, mode, registry=PruneModeRegistry)
+    # apply prune mode(s) to model and convert it to DynamicModule
+    model = mtn.convert(model, mode, registry=PruneModeRegistry)

     # now run the search and return the result
     return mtn.search(
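The public pruning API is behaviorally unchanged by this refactor; only the conversion step now flows through `mtn.convert`, which forwards the registry to `apply_mode` as before. A hedged toy example of the refactored path (the `fastnas` mode, constraint value, and constant score function are illustrative assumptions, not part of this diff):

```python
# Toy run through the refactored path: mtp.prune() -> mtn.convert() -> mtn.search().
import torch
from torch import nn

import modelopt.torch.prune as mtp

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 8, 3))
pruned_model, state = mtp.prune(
    model,
    mode="fastnas",
    constraints={"flops": "80%"},  # illustrative constraint
    dummy_input=torch.randn(1, 3, 32, 32),
    config={"score_func": lambda m: 0.0},  # placeholder validation score
)
assert pruned_model is model  # per the docstring: same model instance is returned
```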