<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
For more technical details, please refer to our CVPR'23 paper.
### Update:
- 🔥 2025.03.24 Examples for pruning [**DeepSeek-R1-Distill**](https://github.com/VainF/Torch-Pruning/tree/master/examples/LLMs).
- 🔥 2024.11.17 We are working to add more [**examples for LLMs**](https://github.com/VainF/Torch-Pruning/tree/master/examples/LLMs), such as Llama-2/3, Phi-3, and Qwen-2/2.5.
- 🔥 2024.09.27 Check out our latest work, [**MaskLLM (NeurIPS'24 Spotlight)**](https://github.com/NVlabs/MaskLLM), for learnable semi-structured sparsity of LLMs.
- 🔥 2024.07.20 Added [**Isomorphic Pruning (ECCV'24)**](https://arxiv.org/abs/2407.04616), a SOTA method for Vision Transformers and modern CNNs.
Or join our WeChat group for more discussions: ✉️ [Group-2](https://github.c
- [Installation](#installation)
- [Quickstart](#quickstart)
- [Why Torch-Pruning?](#why-torch-pruning)
- [How It Works: DepGraph](#how-it-works-depgraph)
- [High-level Pruners](#high-level-pruners)
- [Global Pruning and Isomorphic Pruning](#global-pruning-and-isomorphic-pruning)
- [Pruning Ratios](#pruning-ratios)
In structural pruning, the removal of a single parameter may affect multiple layers.

<img src="assets/dep.png" width="100%">
</div>
### How It Works: DepGraph

> [!IMPORTANT]
> Please make sure that autograd is enabled, since TP analyzes the model structure with PyTorch's autograd engine. This means you need to remove ``torch.no_grad()`` or similar contexts before building the dependency graph.
model.zero_grad() # clear gradients to avoid a large file size
torch.save(model, 'model.pth') # !! no .state_dict here since the structure has been changed after pruning
model = torch.load('model.pth') # load the pruned model; PyTorch 2.6.0+ may require torch.load('model.pth', weights_only=False)
```
The above example shows the core algorithm, DepGraph, which captures the dependencies in structural pruning. The target layer `model.conv1` is coupled with several other layers, which must be removed simultaneously in structural pruning. We can print the group to inspect the internal dependencies. In the outputs below, "A => B" indicates that pruning operation "A" triggers pruning operation "B". The first element, `group[0]`, is the root of the pruning. For more details about grouping, please refer to [Wiki - DepGraph & Group](https://github.com/VainF/Torch-Pruning/wiki/3.-DepGraph-&-Group).
```python
print(group.details()) # or print(group)
```
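To build intuition for the "A => B" triggering described above, here is a toy sketch in plain Python. This is not TP's actual implementation; the layer names and the dependency table are made up for illustration:

```python
# Toy dependency table: pruning the key's channels forces the listed
# (layer, dimension) pruning operations as well.
DEPS = {
    ("conv1", "out"): [("bn1", "out"), ("conv2", "in")],
    ("bn1", "out"): [],
    ("conv2", "in"): [],
}

def collect_group(root):
    """Gather every pruning operation transitively triggered by `root`."""
    group, stack, seen = [], [root], set()
    while stack:
        op = stack.pop()
        if op in seen:
            continue
        seen.add(op)
        group.append(op)
        stack.extend(DEPS.get(op, []))
    return group

print(collect_group(("conv1", "out")))
# [('conv1', 'out'), ('conv2', 'in'), ('bn1', 'out')] — all must be pruned together
```

DepGraph automates the discovery of such dependency tables for arbitrary PyTorch models, which is exactly what makes coupled layers safe to prune as one group.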
### High-level Pruners
> [!NOTE]
> **The pruning ratio**: In TP, ``pruning_ratio`` refers to the pruning ratio of channels/dims. Since both the in & out dims will be reduced by a ratio $p$, the actual parameter pruning ratio will be roughly $1-(1-p)^2$. To remove 50% of the parameters, you may use ``pruning_ratio=0.30``, which leads to an actual parameter pruning ratio of $1-(1-0.3)^2=0.51$ (51% of parameters removed).
With DepGraph, we developed several high-level pruners to facilitate effortless pruning. By specifying the desired channel pruning ratio, the pruner will scan all prunable groups, estimate weight importance, and perform pruning. You can then fine-tune the remaining weights using your own training code. For detailed information on this process, please refer to [this tutorial](https://github.com/VainF/Torch-Pruning/blob/master/examples/notebook/1%20-%20Customize%20Your%20Own%20Pruners.ipynb), which shows how to implement a [Network Slimming (ICCV 2017)](https://arxiv.org/abs/1708.06519) pruner from scratch. A more practical example of ViT and ConvNeXt pruning is available in [VainF/Isomorphic-Pruning](https://github.com/VainF/Isomorphic-Pruning).
```python
import torch
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# 1. Importance criterion: here we use the L2 norm of grouped weights as the importance score
imp = tp.importance.GroupMagnitudeImportance(p=2)

# 2. Initialize a pruner with the model and the importance criterion
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

pruner = tp.pruner.BasePruner( # We can always choose BasePruner if sparse training is not required.
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, # remove 50% of the channels
    ignored_layers=ignored_layers,
)

# 3. Prune the model and count MACs/params before and after pruning
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruner.step()
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")
# 4. finetune the pruned model using your own code.
# finetune(model)
# ...
```
<details>
<summary>Output</summary>

The model difference before and after pruning will be highlighted by something like `(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) => (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)`.

```
MACs: 1.822177768 G -> 0.487202536 G, #Params: 11.689512 M -> 3.05588 M
```

</details>
#### Global Pruning and Isomorphic Pruning
Global pruning performs importance ranking across all layers, which has the potential to find better structures. This can be easily achieved by setting ``global_pruning=True`` in the pruner. While this strategy may offer performance advantages, it also risks over-pruning specific layers, which can cause a substantial decline in overall performance. We provide an alternative algorithm, [Isomorphic Pruning](https://arxiv.org/abs/2407.04616), to alleviate this issue; it can be enabled with ``isomorphic=True``. Comprehensive examples of ViT & ConvNeXt pruning are available in [this project](https://github.com/VainF/Isomorphic-Pruning).
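The difference between local and global ranking can be illustrated with a toy example in plain Python (made-up importance scores, not TP's API): local pruning removes the lowest-scoring fraction within each layer, while global pruning ranks all channels together and may concentrate removals in a single layer:

```python
scores = {
    "layer1": [0.9, 0.8, 0.7, 0.6],   # uniformly important
    "layer2": [0.1, 0.2, 0.3, 0.4],   # uniformly unimportant
}

def local_prune(scores, ratio):
    """Remove the lowest-scoring `ratio` of channels in every layer."""
    return {
        name: sorted(range(len(s)), key=lambda i: s[i])[: int(len(s) * ratio)]
        for name, s in scores.items()
    }

def global_prune(scores, ratio):
    """Rank all channels together; remove the globally lowest-scoring ones."""
    flat = [(s[i], name, i) for name, s in scores.items() for i in range(len(s))]
    flat.sort()
    pruned = {name: [] for name in scores}
    for _, name, i in flat[: int(len(flat) * ratio)]:
        pruned[name].append(i)
    return pruned

print(local_prune(scores, 0.5))   # {'layer1': [3, 2], 'layer2': [0, 1]}
print(global_prune(scores, 0.5))  # {'layer1': [], 'layer2': [0, 1, 2, 3]} — layer2 fully removed!
```

The second output shows the failure mode the paragraph warns about: a global ranking can wipe out an entire layer, which is what Isomorphic Pruning is designed to mitigate.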
```python
pruner = tp.pruner.BasePruner(
    ...
    isomorphic=True, # enable isomorphic pruning to improve global ranking
)
```
The argument ``pruning_ratio`` determines the default pruning ratio. If you want to customize the pruning ratio for some layers or blocks, you can use ``pruning_ratio_dict``. A key of the dict can be a single ``nn.Module`` or a tuple of ``nn.Module``s; in the latter case, all modules in the tuple form a ``scope``, sharing the user-defined pruning ratio and competing with each other to be pruned.
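The key lookup described above can be sketched in plain Python (a hypothetical helper with placeholder modules, not TP's internal code): a direct module key wins, a tuple key applies its ratio to every member of the scope, and everything else falls back to the default:

```python
def resolve_ratio(module, ratio_dict, default):
    """Return the pruning ratio for `module`: a direct key wins, a tuple
    (scope) key applies to every member, otherwise use the default."""
    for key, ratio in ratio_dict.items():
        if isinstance(key, tuple):
            if module in key:
                return ratio
        elif module is key:
            return ratio
    return default

# Hypothetical placeholders standing in for nn.Module instances
layer1, layer2, layer3 = object(), object(), object()
ratio_dict = {layer1: 0.2, (layer2, layer3): 0.6}  # tuple key = shared scope

print(resolve_ratio(layer1, ratio_dict, 0.5))    # 0.2
print(resolve_ratio(layer2, ratio_dict, 0.5))    # 0.6 (scope ratio)
print(resolve_ratio(object(), ratio_dict, 0.5))  # 0.5 (default)
```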
```python
pruner = tp.pruner.BasePruner(
    ...
    global_pruning=True,
    pruning_ratio=0.5, # default pruning ratio
    pruning_ratio_dict={(model.layer2, model.layer3): 0.2}, # a scope: these blocks share the 0.2 ratio and compete to be pruned
)
```
The following script saves the whole model object (structure + weights) as a 'model.pth' file:

```python
model.zero_grad() # Remove gradients
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the pruned model
# For PyTorch 2.6.0+, you may need weights_only=False to enable model loading
# model = torch.load('model.pth', weights_only=False)
```