
Commit d4af5ff

Merge pull request #472 from VainF/v2.0 (V2.0)
2 parents c122b85 + 86f0271

34 files changed: +1056 -902 lines

README.md

Lines changed: 47 additions & 11 deletions
````diff
@@ -11,7 +11,7 @@
 <a href="https://pytorch.org/"><img src="https://img.shields.io/badge/PyTorch-1.x%20%7C%202.x-673ab7.svg" alt="Tested PyTorch Versions"></a>
 <a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-4caf50.svg" alt="License"></a>
 <a href="https://pepy.tech/project/Torch-Pruning"><img src="https://static.pepy.tech/badge/Torch-Pruning?color=2196f3" alt="Downloads"></a>
-<a href="https://github.com/VainF/Torch-Pruning/releases/latest"><img src="https://img.shields.io/badge/Latest%20Version-1.5.1-3f51b5.svg" alt="Latest Version"></a>
+<a href="https://github.com/VainF/Torch-Pruning/releases/latest"><img src="https://img.shields.io/badge/Latest%20Version-1.5.2-3f51b5.svg" alt="Latest Version"></a>
 <a href="https://colab.research.google.com/drive/1TRvELQDNj9PwM-EERWbF3IQOyxZeDepp?usp=sharing">
 <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
 </a>
@@ -33,6 +33,7 @@ For more technical details, please refer to our CVPR'23 paper.
 
 
 ### Update:
+- 🔥 2025.03.24 Examples for pruning [**DeepSeek-R1-Distill**](https://github.com/VainF/Torch-Pruning/tree/master/examples/LLMs).
 - 🔥 2024.11.17 We are working to add more [**examples for LLMs**](https://github.com/VainF/Torch-Pruning/tree/master/examples/LLMs), such as Llama-2/3, Phi-3, Qwen-2/2.5.
 - 🔥 2024.09.27 Check our latest work, [**MaskLLM (NeurIPS 24 Spotlight)**](https://github.com/NVlabs/MaskLLM), for learnable semi-structured sparsity of LLMs.
 - 🔥 2024.07.20 Add [**Isomorphic Pruning (ECCV'24)**](https://arxiv.org/abs/2407.04616), a SOTA method for Vision Transformers and modern CNNs.
@@ -45,7 +46,7 @@ Or Join our WeChat group for more discussions: ✉️ [Group-2](https://github.c
 - [Installation](#installation)
 - [Quickstart](#quickstart)
 - [Why Torch-Pruning?](#why-torch-pruning)
-  - [A Minimal Example of DepGraph](#a-minimal-example-of-depgraph)
+  - [How It Works: DepGraph](#how-it-works-depgraph)
   - [High-level Pruners](#high-level-pruners)
   - [Global Pruning and Isomorphic Pruning](#global-pruning-and-isomorphic-pruning)
   - [Pruning Ratios](#pruning-ratios)
@@ -88,7 +89,7 @@ In structural pruning, the removal of a single parameter may affect multiple lay
 <img src="assets/dep.png" width="100%">
 </div>
 
-### A Minimal Example of DepGraph
+### How It Works: DepGraph
 
 > [!IMPORTANT]
 > Please make sure that AutoGrad is enabled, since TP analyzes the model structure with PyTorch AutoGrad. This means we need to remove ``torch.no_grad()`` or anything similar when building the dependency graph.
@@ -113,9 +114,9 @@ if DG.check_pruning_group(group): # avoid over-pruning, i.e., channels=0.
 # 4. Save & Load
 model.zero_grad() # clear gradients to avoid a large file size
 torch.save(model, 'model.pth') # !! no .state_dict here since the structure has been changed after pruning
-model = torch.load('model.pth') # load the pruned model
+model = torch.load('model.pth') # load the pruned model. You may need torch.load('model.pth', weights_only=False) for PyTorch 2.6.0+.
 ```
-The above example shows the basic pruning pipeline using DepGraph. The target layer `model.conv1` is coupled with multiple layers, necessitating their simultaneous removal in structural pruning. We can print the group to take a look at the internal dependencies. In the subsequent outputs, "A => B" indicates that pruning operation "A" triggers pruning operation "B." The first group[0] refers to the root of pruning. For more details about grouping, please refer to [Wiki - DepGraph & Group](https://github.com/VainF/Torch-Pruning/wiki/3.-DepGraph-&-Group).
+The above example shows the core algorithm, DepGraph, which captures the dependencies in structural pruning. The target layer `model.conv1` is coupled with multiple layers, necessitating their simultaneous removal in structural pruning. We can print the group to inspect the internal dependencies. In the subsequent outputs, "A => B" indicates that pruning operation "A" triggers pruning operation "B." The first item, group[0], refers to the root of the pruning. For more details about grouping, please refer to [Wiki - DepGraph & Group](https://github.com/VainF/Torch-Pruning/wiki/3.-DepGraph-&-Group).
 
 ```python
 print(group.details()) # or print(group)
````
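The hunk above touches only step 4 (save & load) of the README's minimal example. For orientation, here is a sketch of the full pipeline that step belongs to, reconstructed from the DepGraph calls visible in the hunk headers; the resnet18 model and the pruned channel indices are illustrative assumptions:

```python
import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()
example_inputs = torch.randn(1, 3, 224, 224)

# 1. Build the dependency graph by tracing the model with example inputs.
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# 2. Group model.conv1 with all coupled layers; remove output channels 2, 6, and 9.
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9])

# 3. Prune the whole group in one step.
if DG.check_pruning_group(group):  # avoid over-pruning, i.e., channels=0
    group.prune()

# 4. Save & load, as in the hunk above.
model.zero_grad()
torch.save(model, 'model.pth')
```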
````diff
@@ -156,6 +157,9 @@ for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[
 
 ### High-level Pruners
 
+> [!NOTE]
+> **The pruning ratio**: In TP, ``pruning_ratio`` refers to the pruning ratio of channels/dims. Since both the in & out dims are reduced by $p$, the actual ``parameter_pruning_ratio`` will be roughly $1-(1-p)^2$. To remove 50% of the parameters, you may use ``pruning_ratio=0.30`` instead, which leads to an actual parameter pruning ratio of $1-(1-0.3)^2=0.51$ (51% of parameters removed).
+
 With DepGraph, we developed several high-level pruners to facilitate effortless pruning. By specifying the desired channel pruning ratio, the pruner will scan all prunable groups, estimate weight importance, and perform pruning. You can fine-tune the remaining weights using your own training code. For detailed information on this process, please refer to [this tutorial](https://github.com/VainF/Torch-Pruning/blob/master/examples/notebook/1%20-%20Customize%20Your%20Own%20Pruners.ipynb), which shows how to implement a [Network Slimming (ICCV 2017)](https://arxiv.org/abs/1708.06519) pruner from scratch. Additionally, a more practical example is available in [VainF/Isomorphic-Pruning](https://github.com/VainF/Isomorphic-Pruning) for ViT and ConvNext pruning.
 
 ```python
````
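The note's arithmetic is easy to verify and invert. A minimal sketch (the helper name is ours, not part of the TP API):

```python
import math

def channel_ratio_for(param_ratio: float) -> float:
    """Invert param_ratio = 1 - (1 - p)**2 to recover the channel pruning ratio p."""
    return 1.0 - math.sqrt(1.0 - param_ratio)

print(channel_ratio_for(0.51))  # ~0.30: pruning 30% of channels removes ~51% of parameters
print(channel_ratio_for(0.50))  # ~0.29 for exactly half of the parameters
```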
````diff
@@ -167,15 +171,15 @@ model = resnet18(pretrained=True)
 example_inputs = torch.randn(1, 3, 224, 224)
 
 # 1. Importance criterion: here we calculate the L2 norm of grouped weights as the importance score
-imp = tp.importance.GroupNormImportance(p=2)
+imp = tp.importance.GroupMagnitudeImportance(p=2)
 
 # 2. Initialize a pruner with the model and the importance criterion
 ignored_layers = []
 for m in model.modules():
     if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
         ignored_layers.append(m) # DO NOT prune the final classifier!
 
-pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required.
+pruner = tp.pruner.BasePruner( # We can always choose BasePruner if sparse training is not required.
     model,
     example_inputs,
     importance=imp,
@@ -187,23 +191,53 @@ pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse train
 
 # 3. Prune the model
 base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
+tp.utils.print_tool.before_pruning(model) # or print(model)
 pruner.step()
+tp.utils.print_tool.after_pruning(model) # or print(model); this util will show the difference before and after pruning
 macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
 print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")
 
+
 # 4. Finetune the pruned model using your own code.
 # finetune(model)
 # ...
 ```
 
-> [!NOTE]
-> **About the pruning ratio**: In TP, the ``pruning_ratio`` refers to the pruning ratio of channels/dims. Since both in & out dims will be removed by p%, the actual ``parameter_pruning_ratio`` will be roughly 1-(1-p%)^2. To remove 50% of parameters, you may use ``pruning_ratio=0.30`` instead, which yields a ``parameter_pruning_ratio`` of 1-(1-0.3)^2=0.51 (51% parameters removed).
+<details>
+<summary>Output</summary>
+
+The model difference before and after pruning will be highlighted by something like `(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) => (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)`.
+```
+ResNet(
+  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) => (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
+  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) => (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+  (relu): ReLU(inplace=True)
+  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
+  ...
+    (1): BasicBlock(
+      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) => (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) => (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+      (relu): ReLU(inplace=True)
+      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) => (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) => (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+    )
+  )
+  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
+  (fc): Linear(in_features=512, out_features=1000, bias=True) => (fc): Linear(in_features=256, out_features=1000, bias=True)
+)
+
+MACs: 1.822177768 G -> 0.487202536 G, #Params: 11.689512 M -> 3.05588 M
+```
+</details>
+
+
+
 
 #### Global Pruning and Isomorphic Pruning
 Global pruning performs importance ranking across all layers, which has the potential to find better structures. This can be easily achieved by setting ``global_pruning=True`` in the pruner. While this strategy can possibly offer performance advantages, it also carries the risk of over-pruning specific layers, resulting in a substantial decline in overall performance. We provide an alternative algorithm called [Isomorphic Pruning](https://arxiv.org/abs/2407.04616) to alleviate this issue, which can be enabled with ``isomorphic=True``. Comprehensive examples for ViT & ConvNext pruning are available in [this project](https://github.com/VainF/Isomorphic-Pruning).
 
 ```python
-pruner = tp.pruner.MetaPruner(
+pruner = tp.pruner.BasePruner(
     ...
     isomorphic=True, # enable isomorphic pruning to improve global ranking
     global_pruning=True, # global pruning
@@ -218,7 +252,7 @@ pruner = tp.pruner.MetaPruner(
 
 The argument ``pruning_ratio`` determines the default pruning ratio. If you want to customize the pruning ratio for some layers or blocks, you can use ``pruning_ratio_dict``. The key of the dict can be a single ``nn.Module`` or a tuple of ``nn.Module``s. In the latter case, all modules in the tuple form a ``scope``, sharing the user-defined pruning ratio and competing with each other to be pruned.
 ```python
-pruner = tp.pruner.MetaPruner(
+pruner = tp.pruner.BasePruner(
     ...
     global_pruning=True,
     pruning_ratio=0.5, # default pruning ratio
````
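To make the scope semantics concrete, a sketch of ``pruning_ratio_dict`` usage; it reuses ``model``, ``imp``, and ``ignored_layers`` from the ResNet example above, and the ``layer1``/``layer2``/``layer3`` attributes assume torchvision's resnet18:

```python
pruner = tp.pruner.BasePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5,  # default pruning ratio for all other layers
    pruning_ratio_dict={
        model.layer1: 0.2,                  # a single nn.Module keeps its own ratio
        (model.layer2, model.layer3): 0.4,  # a tuple forms a "scope": both blocks share 0.4 and compete to be pruned
    },
    ignored_layers=ignored_layers,
)
```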
````diff
@@ -313,6 +347,8 @@ The following script saves the whole model object (structure+weights) as a 'mode
 model.zero_grad() # Remove gradients
 torch.save(model, 'model.pth') # without .state_dict
 model = torch.load('model.pth') # load the pruned model
+# For PyTorch 2.6.0+, you may need weights_only=False to enable model loading
+# model = torch.load('model.pth', weights_only=False)
 ```
 
 ### Low-level Pruning Functions
````
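The added comments reflect a behavior change in PyTorch 2.6, where ``torch.load`` defaults to ``weights_only=True`` and refuses to unpickle full ``nn.Module`` objects. A version-agnostic load might look like the sketch below; only disable ``weights_only`` for checkpoints you trust:

```python
import torch

try:
    model = torch.load('model.pth')  # works on PyTorch < 2.6, where weights_only defaults to False
except Exception:
    # PyTorch 2.6.0+ defaults to weights_only=True, which rejects pickled modules.
    model = torch.load('model.pth', weights_only=False)
```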

README_CN.md

Lines changed: 3 additions & 3 deletions
````diff
@@ -43,7 +43,7 @@ Torch-Pruning (TP) is a library for structural pruning with the following features:
 
 
 ### **Key Features:**
-- [x] High-level pruners: [MetaPruner](torch_pruning/pruner/algorithms/metapruner.py), [MagnitudePruner](https://arxiv.org/abs/1608.08710), [BNScalePruner](https://arxiv.org/abs/1708.06519), [GroupNormPruner](https://arxiv.org/abs/2301.12900), [GrowingRegPruner](https://arxiv.org/abs/2012.09243), RandomPruner, etc. A list of related papers can be found on our [wiki page](https://github.com/VainF/Torch-Pruning/wiki/0.-Paper-List).
+- [x] High-level pruners: [BasePruner](torch_pruning/pruner/algorithms/BasePruner.py), [MagnitudePruner](https://arxiv.org/abs/1608.08710), [BNScalePruner](https://arxiv.org/abs/1708.06519), [GroupNormPruner](https://arxiv.org/abs/2301.12900), [GrowingRegPruner](https://arxiv.org/abs/2012.09243), RandomPruner, etc. A list of related papers can be found on our [wiki page](https://github.com/VainF/Torch-Pruning/wiki/0.-Paper-List).
 - [x] Dependency graph for automated structural pruning
 - [x] [Low-level pruning functions](torch_pruning/pruner/function.py)
 - [x] Supported importance criteria: L-p norm, Taylor, Random, BNScaling, etc.
@@ -179,15 +179,15 @@ model = resnet18(pretrained=True)
 example_inputs = torch.randn(1, 3, 224, 224)
 
 # 1. Importance criterion
-imp = tp.importance.GroupTaylorImportance() # or GroupNormImportance(p=2), GroupHessianImportance(), etc.
+imp = tp.importance.GroupTaylorImportance() # or GroupMagnitudeImportance(p=2), GroupHessianImportance(), etc.
 
 # 2. Initialize a pruner with the model and the importance criterion
 ignored_layers = []
 for m in model.modules():
     if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
         ignored_layers.append(m) # DO NOT prune the final classifier!
 
-pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required.
+pruner = tp.pruner.BasePruner( # We can always choose BasePruner if sparse training is not required.
     model,
     example_inputs,
     importance=imp,
````

examples/LLMs/prune_llm.py

Lines changed: 10 additions & 8 deletions
````diff
@@ -288,11 +288,10 @@ def main():
     ##############
     # Pruning
     ##############
-    print("----------------- Before Pruning -----------------")
-    print(model)
+    import torch_pruning as tp
+    tp.utils.print_tool.before_pruning(model)
     text = "Hello world."
     inputs = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device)
-    import torch_pruning as tp
     num_heads = {}
     out_channel_groups = {}
     seperate_qkv = False
@@ -313,12 +312,13 @@ def main():
     _is_gqa = model.config.num_attention_heads != model.config.num_key_value_heads
     head_pruning_ratio = args.pruning_ratio
    hidden_size_pruning_ratio = args.pruning_ratio
-    importance = tp.importance.GroupNormImportance(p=2, group_reduction='mean') # tp.importance.ActivationImportance(p=2, target_types=[torch.nn.Linear])
-    pruner = tp.pruner.MetaPruner(
+    importance = tp.importance.GroupMagnitudeImportance(p=2, group_reduction='mean') # tp.importance.ActivationImportance(p=2, target_types=[torch.nn.Linear])
+    pruner = tp.pruner.BasePruner(
         model,
         example_inputs=inputs,
         importance=importance,
         global_pruning=False,
+        output_transform=lambda x: x.logits,
         pruning_ratio=hidden_size_pruning_ratio,
         ignored_layers=[model.lm_head],
         num_heads=num_heads,
````
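The newly added ``output_transform`` matters because HuggingFace causal LMs return an output object rather than a raw tensor, while TP's tracer needs a single tensor to build the dependency graph. A minimal illustration of what the lambda extracts (``DummyOutput`` is a stand-in of ours for ``CausalLMOutputWithPast``):

```python
import torch
from dataclasses import dataclass

@dataclass
class DummyOutput:
    logits: torch.Tensor  # stands in for transformers' CausalLMOutputWithPast

output_transform = lambda x: x.logits  # the same mapping as in the hunk above
out = DummyOutput(logits=torch.randn(1, 4, 32000))
print(output_transform(out).shape)  # torch.Size([1, 4, 32000]), a plain tensor TP can trace
```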
````diff
@@ -356,7 +356,10 @@ def main():
             #m.head_dim = m.q_proj.out_features // m.num_heads
             if not _is_gqa:
                 m.num_key_value_heads = m.num_heads
-                m.num_key_value_groups = m.num_heads // m.num_key_value_heads
+                model.config.num_key_value_heads = m.num_heads
+            if hasattr(m, "num_key_value_groups"):
+                m.num_key_value_groups = m.num_heads // model.config.num_key_value_heads
+
         elif name.endswith("mlp"):
             if hasattr(m, "gate_proj"):
                 m.hidden_size = m.gate_proj.in_features
````
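For context on the head bookkeeping in this hunk: in grouped-query attention, several query heads share one key/value head, so the group count must be recomputed whenever heads are pruned. A worked example using Llama-3-8B's published head counts (chosen for illustration only):

```python
num_attention_heads = 32  # Llama-3-8B config values, for illustration
num_key_value_heads = 8
print(num_attention_heads // num_key_value_heads)  # 4 query heads share each KV head

# If pruning removes 25% of the query heads while the KV heads are kept,
# the recomputed group count shrinks accordingly:
pruned_attention_heads = 24
print(pruned_attention_heads // num_key_value_heads)  # 3
```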
````diff
@@ -369,8 +372,7 @@ def main():
 
     if not _is_gqa:
         model.config.num_key_value_heads = model.config.num_attention_heads
-    print("----------------- After Pruning -----------------")
-    print(model)
+    tp.utils.print_tool.after_pruning(model, do_print=True)
     print(model.config)
 
 
````