Skip to content

Commit 86f0271

Browse files
authored
Update README.md
1 parent bb4a573 commit 86f0271

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ if DG.check_pruning_group(group): # avoid over-pruning, i.e., channels=0.
114114
# 4. Save & Load
115115
model.zero_grad() # clear gradients to avoid a large file size
116116
torch.save(model, 'model.pth') # !! no .state_dict here since the structure has been changed after pruning
117-
model = torch.load('model.pth') # load the pruned model
117+
model = torch.load('model.pth') # load the pruned model. you may need torch.load('model.pth', weights_only=False) for PyTorch 2.6.0+.
118118
```
119119
The above example shows the core algorithm, DepGraph, that captures the dependencies in structural pruning. The target layer `model.conv1` is coupled with multiple layers, necessitating their simultaneous removal in structural pruning. We can print the group to take a look at the internal dependencies. In the subsequent outputs, "A => B" indicates that pruning operation "A" triggers pruning operation "B." The first group[0] refers to the root of pruning. For more details about grouping, please refer to [Wiki - DepGraph & Group](https://github.com/VainF/Torch-Pruning/wiki/3.-DepGraph-&-Group).
120120

@@ -347,6 +347,8 @@ The following script saves the whole model object (structure+weights) as a 'mode
347347
model.zero_grad() # Remove gradients
348348
torch.save(model, 'model.pth') # without .state_dict
349349
model = torch.load('model.pth') # load the pruned model
350+
# For PyTorch 2.6.0+, you may need weights_only=False to enable model loading
351+
# model = torch.load('model.pth', weights_only=False)
350352
```
351353
352354
### Low-level Pruning Functions

0 commit comments

Comments
 (0)