Skip to content

Commit cb44c55

Browse files
Minor pruning update (#397)
Signed-off-by: Keval Morabia <[email protected]>
1 parent fab4b41 commit cb44c55

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

examples/pruning/README.md

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -55,8 +55,8 @@ model = GPTModel(
5555
# For Megatron-LM framework, you can use the following utility function
5656
from megatron.training.training import evaluate_and_print_results
5757

58-
def forward_loop(model):
59-
evaluate_and_print_results(model, ...)
58+
def forward_loop(_):
59+
evaluate_and_print_results(prefix, forward_step, train_iterator, model, ...)
6060

6161

6262
# Specify the pruning constraints (Check Support Matrix for available pruning dimensions)
@@ -66,7 +66,7 @@ export_config = {
6666
}
6767

6868

69-
# Run the pruning process
69+
# Run the pruning process (if model is a list then pass model[0] to the prune API)
7070
# Save minitron scores at scores_path so we can re-run pruning with different export configs without running the forward loop again
7171
# NOTE: Skip scores_path on re-running if you want to change the dataset and re-calibrate
7272
model, pruning_scores = mtp.prune(
@@ -81,7 +81,7 @@ model, pruning_scores = mtp.prune(
8181
If your model parameters are already sorted, you can skip the sorting step by setting `"skip_sorting": True` in `config` instead of passing `forward_loop`.
8282

8383
> [!Note]
84-
> Fine-tuning / distillation is required after pruning to recover the accuracy. Please refer to pruning [fine-tuning](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/3_pruning.html#pruning-fine-tuning) for more details.
84+
> Fine-tuning / distillation is required after pruning to recover the accuracy. Please refer to [end-to-end pruning and distillation tutorial](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) for more details.
8585
8686
## Support Matrix
8787

modelopt/torch/nas/conversion.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,9 @@
2626
NASModeRegistry = _ModeRegistryCls("nas")
2727

2828

29-
def convert(model: ModelLike, mode: ModeLike) -> nn.Module:
29+
def convert(
30+
model: ModelLike, mode: ModeLike, registry: _ModeRegistryCls = NASModeRegistry
31+
) -> nn.Module:
3032
"""Convert a regular PyTorch model into a model that supports design space optimization.
3133
3234
Args:
@@ -84,7 +86,7 @@ def convert(model: ModelLike, mode: ModeLike) -> nn.Module:
8486
#. Use ``*`` as a wildcard matching any layer.
8587
"""
8688
# apply mode and handle model-like object with wrapper
87-
return apply_mode(model, mode, registry=NASModeRegistry)
89+
return apply_mode(model, mode, registry=registry)
8890

8991

9092
def export(model: nn.Module, strict: bool = True, calib: bool = False) -> nn.Module:

modelopt/torch/prune/pruning.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@
2020
from torch import nn
2121

2222
import modelopt.torch.nas as mtn
23-
from modelopt.torch.opt.conversion import apply_mode
2423
from modelopt.torch.opt.mode import ModeLike, _ModeRegistryCls
2524
from modelopt.torch.opt.searcher import ConstraintsDict, SearchConfig
2625

@@ -199,8 +198,8 @@ def prune(
199198
search algorithm. The returned subnet is thus a reference to the same model instance as the
200199
input model.
201200
"""
202-
# apply prune mode(s) to model
203-
model = apply_mode(model, mode, registry=PruneModeRegistry)
201+
# apply prune mode(s) to model and convert it to DynamicModule
202+
model = mtn.convert(model, mode, registry=PruneModeRegistry)
204203

205204
# now run the search and return the result
206205
return mtn.search(

0 commit comments

Comments (0)