
Commit 5663afd

[pt] rb pruning (#3688)

### Changes

- Add RB (regularization-based) pruning algorithm
- Rename the `prune` folder to `pruning`
- Add parameters to select the pruning mode in the example

### Related tickets

173791

1 parent 6496b96 · commit 5663afd

File tree: 30 files changed (+1022, −175 lines)
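
Before the per-file diffs, a condensed sketch of how the new RB mode is driven from user code may help orientation. It is assembled from the example changes in this commit and simplified: `model`, `example_input`, `train_loader`, `criterion`, and `device` are placeholders for objects the full example builds, and the per-parameter-group optimizer used in the real example is collapsed into a single `Adam`; check signatures against the diffs rather than this summary.

```python
import torch

import nncf
from nncf.parameters import PruneMode
from nncf.torch.function_hook.pruning.rb.losses import RBLoss
from nncf.torch.function_hook.pruning.rb.schedulers import MultiStepRBPruningScheduler

# Wrap the model for RB pruning. Unlike the magnitude modes, no ratio is
# passed to prune(); the target ratio is owned by RBLoss during fine-tuning.
pruned_model = nncf.prune(
    model,
    mode=PruneMode.UNSTRUCTURED_REGULARIZATION_BASED,
    ignored_scope=nncf.IgnoredScope(),
    examples_inputs=example_input,
)
rb_loss = RBLoss(pruned_model, target_ratio=0.7, p=0.1).to(device)
scheduler = MultiStepRBPruningScheduler(rb_loss, steps={0: 0.3, 5: 0.5, 10: 0.7})
optimizer = torch.optim.Adam(pruned_model.parameters(), lr=1e-5)

for epoch in range(30):
    scheduler.step()  # raise the scheduled target ratio at epochs 0, 5, 10
    for images, target in train_loader:
        output = pruned_model(images.to(device))
        # Task loss plus the sparsity-regularization term.
        loss = criterion(output, target.to(device)) + rb_loss()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```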

examples/pruning/torch/resnet18/README.md

Lines changed: 2 additions & 0 deletions

````diff
@@ -32,4 +32,6 @@ It's pretty simple. The example does not require additional preparation. It will
 
 ```bash
 python main.py
+# Or to run Regularization-Based pruning
+python main.py --mode rb
 ```
````

examples/pruning/torch/resnet18/main.py

Lines changed: 89 additions & 94 deletions

```diff
@@ -11,49 +11,54 @@
 
 import os
 import warnings
+from argparse import ArgumentParser
 from pathlib import Path
 
 import openvino as ov
 import torch
-import torch.optim
-import torch.utils.data
-import torch.utils.data.distributed
-import torchvision.datasets as datasets
-import torchvision.models as models
-import torchvision.transforms as transforms
 from fastdownload import FastDownload
 from rich.progress import track
 from torch import nn
 from torch.jit import TracerWarning
 from torch.utils.data import DataLoader
+from torchvision import datasets
+from torchvision import transforms
+from torchvision.models import resnet18
 
 import nncf
-import nncf.parameters
-import nncf.torch
-import nncf.torch.function_hook
-import nncf.torch.function_hook.prune
-import nncf.torch.function_hook.prune.prune_model
 from nncf.parameters import PruneMode
-from nncf.torch.function_hook.prune.magnitude.schedulers import MultiStepMagnitudePruningScheduler
+from nncf.torch.function_hook.pruning.magnitude.schedulers import MultiStepMagnitudePruningScheduler
+from nncf.torch.function_hook.pruning.rb.losses import RBLoss
+from nncf.torch.function_hook.pruning.rb.schedulers import MultiStepRBPruningScheduler
 
 warnings.filterwarnings("ignore", category=TracerWarning)
 warnings.filterwarnings("ignore", category=UserWarning)
 
 BASE_MODEL_NAME = "resnet18"
 IMAGE_SIZE = 64
 BATCH_SIZE = 128
-TRAINING_EPOCHS = 2
 
 
 ROOT = Path(__file__).parent.resolve()
-BEST_CKPT_NAME = "resnet18_int8_best.pt"
 CHECKPOINT_URL = (
     "https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/302_resnet18_fp32_v1.pth"
 )
 DATASET_URL = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
 DATASET_PATH = Path().home() / ".cache" / "nncf" / "datasets"
 
 
+def get_argument_parser() -> ArgumentParser:
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--mode",
+        type=str,
+        choices=["magnitude", "rb"],
+        default="magnitude",
+        help="Pruning mode to use. Choices are: magnitude, rb. Default is magnitude.",
+    )
+    return parser
+
+
 def download_dataset() -> Path:
     downloader = FastDownload(base=DATASET_PATH.resolve(), archive="downloaded", data="extracted")
     return downloader.get(DATASET_URL)
@@ -66,10 +71,10 @@ def load_checkpoint(model: nn.Module) -> tuple[nn.Module, float]:
 
 
 def get_resnet18_model(device: torch.device) -> nn.Module:
-    num_classes = 200  # 200 is for Tiny ImageNet, default is 1000 for ImageNet
-    model = models.resnet18(weights=None)
+    model = resnet18(weights=None)
     # Update the last FC layer for Tiny ImageNet number of classes.
-    model.fc = nn.Linear(in_features=512, out_features=num_classes, bias=True)
+    # 200 is for Tiny ImageNet, default is 1000 for ImageNet
+    model.fc = nn.Linear(in_features=512, out_features=200, bias=True)
     model.to(device)
     return model
 
@@ -78,6 +83,7 @@ def train_epoch(
     train_loader: DataLoader,
     model: nn.Module,
     criterion: nn.Module,
+    rb_loss: RBLoss,
    optimizer: torch.optim.Optimizer,
     device: torch.device,
 ):
@@ -91,50 +97,34 @@ def train_epoch(
         # Compute output.
         output = model(images)
         loss = criterion(output, target)
-
+        if rb_loss is not None:
+            loss += rb_loss()
         # Compute gradient and do opt step.
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
 
 
-def validate(val_loader: DataLoader, model: nn.Module, device: torch.device) -> float:
-    top1_sum = 0.0
-
+@torch.no_grad()
+def validate(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module, device: torch.device) -> float:
     # Switch to evaluate mode.
     model.eval()
 
-    with torch.no_grad():
-        for images, target in track(val_loader, total=len(val_loader), description="Validation:"):
-            images = images.to(device)
-            target = target.to(device)
-
-            # Compute output.
-            output = model(images)
-
-            # Measure accuracy and record loss.
-            [acc1] = accuracy(output, target, topk=(1,))
-            top1_sum += acc1.item()
-
-    num_samples = len(val_loader)
-    top1_avg = top1_sum / num_samples
-    return top1_avg
+    correct = 0
+    total = 0
 
+    for images, target in track(val_loader, total=len(val_loader), description="Validation:"):
+        images = images.to(device)
+        target = target.to(device)
 
-def accuracy(output: torch.Tensor, target: torch.tensor, topk: tuple[int, ...] = (1,)):
-    with torch.no_grad():
-        maxk = max(topk)
-        batch_size = target.size(0)
+        output = model(images)
 
-        _, pred = output.topk(maxk, 1, True, True)
-        pred = pred.t()
-        correct = pred.eq(target.view(1, -1).expand_as(pred))
+        _, preds = output.max(1)
+        correct += preds.eq(target).sum().item()
+        total += target.size(0)
 
-        res = []
-        for k in topk:
-            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
-            res.append(correct_k.mul_(100.0 / batch_size))
-        return res
+    accuracy1 = 100.0 * correct / total
+    return accuracy1
 
 
 def create_data_loaders() -> tuple[DataLoader, DataLoader]:
@@ -151,23 +141,12 @@ def create_data_loaders() -> tuple[DataLoader, DataLoader]:
     train_dataset = datasets.ImageFolder(
         train_dir,
         transforms.Compose(
-            [
-                transforms.Resize(IMAGE_SIZE),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                normalize,
-            ]
+            [transforms.Resize(IMAGE_SIZE), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize]
         ),
     )
     val_dataset = datasets.ImageFolder(
         val_dir,
-        transforms.Compose(
-            [
-                transforms.Resize(IMAGE_SIZE),
-                transforms.ToTensor(),
-                normalize,
-            ]
-        ),
+        transforms.Compose([transforms.Resize(IMAGE_SIZE), transforms.ToTensor(), normalize]),
     )
 
     train_loader = DataLoader(
@@ -200,7 +179,10 @@ def prepare_tiny_imagenet_200(dataset_dir: Path) -> None:
     val_images_dir.rmdir()
 
 
-def main():
+def main() -> float:
+    args = get_argument_parser().parse_args()
+    pruning_mode = args.mode
+
     torch.manual_seed(0)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using {device} device")
@@ -212,51 +194,64 @@ def main():
     model = get_resnet18_model(device)
     model, acc1_fp32 = load_checkpoint(model)
 
-    print(f"Accuracy@1 of original FP32 model: {acc1_fp32}")
+    print(f"Accuracy@1 of original FP32 model: {acc1_fp32:.2f}")
 
     train_loader, val_loader = create_data_loaders()
     example_input = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).to(device)
 
     ###############################################################################
     # Step 2: Prune model
-    print(os.linesep + "[Step 2] Prune model")
-
-    # Unstructured pruning with 70% sparsity ratio
-    pruned_model = nncf.prune(
-        model,
-        mode=PruneMode.UNSTRUCTURED_MAGNITUDE_GLOBAL,
-        ratio=0.7,
-        ignored_scope=nncf.IgnoredScope(),
-        examples_inputs=example_input,
-    )
-
-    acc1_init = validate(val_loader, pruned_model, device)
-
-    print(f"Accuracy@1 of pruned model with 0.7 pruning ratio without fine-tuning: {acc1_init:.3f}")
-
-    ###############################################################################
-    # Step 3: Fine tune with multi step sparsity scheduler
-    print(os.linesep + "[Step 3] Fine tune with multi step sparsity scheduler")
+    print(os.linesep + "[Step 2]: Prune model and specify training parameters")
+
+    if pruning_mode == "magnitude":
+        pruned_model = nncf.prune(
+            model,
+            mode=PruneMode.UNSTRUCTURED_MAGNITUDE_GLOBAL,
+            ratio=0.7,
+            ignored_scope=nncf.IgnoredScope(),
+            examples_inputs=example_input,
+        )
+        num_epochs = 2
+        rb_loss = None
+        scheduler = MultiStepMagnitudePruningScheduler(
+            model=model, mode=PruneMode.UNSTRUCTURED_MAGNITUDE_GLOBAL, steps={0: 0.5, 1: 0.7}
+        )
+        optimizer = torch.optim.Adam(pruned_model.parameters(), lr=1e-5)
+    else:
+        pruned_model = nncf.prune(
+            model,
+            mode=PruneMode.UNSTRUCTURED_REGULARIZATION_BASED,
+            ignored_scope=nncf.IgnoredScope(),
+            examples_inputs=example_input,
+        )
+        num_epochs = 30
+        rb_loss = RBLoss(pruned_model, target_ratio=0.7, p=0.1).to(device)
+        scheduler = MultiStepRBPruningScheduler(rb_loss, steps={0: 0.3, 5: 0.5, 10: 0.7})
+
+        # Set higher lr for mask parameters to achieve the target pruning ratio faster
+        mask_params = [p for n, p in pruned_model.named_parameters() if "mask" in n]
+        model_params = [p for n, p in pruned_model.named_parameters() if "mask" not in n]
+        optimizer = torch.optim.Adam(
+            [
+                {"params": model_params, "lr": 1e-5},
+                {"params": mask_params, "lr": 1e-2, "weight_decay": 0.0},
+            ]
+        )
 
-    # Define loss function (criterion) and optimizer.
     criterion = nn.CrossEntropyLoss().to(device)
-    compression_lr = 1e-5
-    optimizer = torch.optim.Adam(pruned_model.parameters(), lr=compression_lr)
 
-    # Create prune scheduler with multi steps strategy
-    pruning_scheduler = MultiStepMagnitudePruningScheduler(
-        pruned_model, mode=PruneMode.UNSTRUCTURED_MAGNITUDE_GLOBAL, steps={0: 0.6, 1: 0.7}
-    )
+    ###############################################################################
+    # Step 3: Fine tune
+    print(os.linesep + "[Step 3] Fine tune with multi step pruning ratio scheduler")
 
-    for epoch in range(2):
+    for epoch in range(num_epochs):
         print(os.linesep + f"Train epoch: {epoch}")
+        scheduler.step()
+        train_epoch(train_loader, pruned_model, criterion, rb_loss, optimizer, device=device)
 
-        pruning_scheduler.step()
-
-        train_epoch(train_loader, pruned_model, criterion, optimizer, device=device)
         acc1 = validate(val_loader, pruned_model, device)
-        # Show statistics of pruning
-        print(f"Accuracy@1 of pruned model after {epoch} epoch ratio {pruning_scheduler.current_ratio}: {acc1:.3f}")
+        print(f"Current pruning ratio: {scheduler.current_ratio:.3f}")
+        print(f"Accuracy@1 of pruned model after {epoch} epoch: {acc1:.3f}")
 
     ###############################################################################
     # Step 4: Export models
```
src/nncf/parameters.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -221,3 +221,4 @@ class QuantizationMode(StrEnum):
 class PruneMode(StrEnum):
     UNSTRUCTURED_MAGNITUDE_LOCAL = auto()
     UNSTRUCTURED_MAGNITUDE_GLOBAL = auto()
+    UNSTRUCTURED_REGULARIZATION_BASED = auto()
```

src/nncf/pruning/prune_model.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -23,7 +23,7 @@ def prune(
     model: TModel,
     *,
     mode: PruneMode,
-    ratio: float,
+    ratio: Optional[float] = None,
     ignored_scope: Optional[IgnoredScope] = None,
     examples_inputs: Optional[Any] = None,
 ) -> TModel:
@@ -40,7 +40,7 @@ def prune(
     """
     backend = get_backend(model)
     if backend == BackendType.TORCH:
-        from nncf.torch.function_hook.prune.prune_model import prune
+        from nncf.torch.function_hook.pruning.prune_model import prune
 
         model = prune(model, mode, ratio, ignored_scope, examples_inputs)
     else:
```
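
The loosened signature reflects how the two mode families are called. A hedged sketch, with `model` and `example_input` as placeholders: magnitude modes still pass `ratio` so the mask is applied immediately, while the RB mode omits it because the target ratio is handed to `RBLoss` at fine-tuning time, as the example diff above shows.

```python
# Magnitude pruning: the ratio takes effect at prune() time.
pruned = nncf.prune(
    model,
    mode=PruneMode.UNSTRUCTURED_MAGNITUDE_GLOBAL,
    ratio=0.7,
    examples_inputs=example_input,
)

# RB pruning: no ratio here; RBLoss(pruned, target_ratio=0.7, ...) carries
# the target during fine-tuning instead.
pruned = nncf.prune(
    model,
    mode=PruneMode.UNSTRUCTURED_REGULARIZATION_BASED,
    examples_inputs=example_input,
)
```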

src/nncf/torch/function_hook/prune/magnitude/__init__.py renamed to src/nncf/torch/function_hook/pruning/magnitude/__init__.py

File renamed without changes.

src/nncf/torch/function_hook/prune/magnitude/algo.py renamed to src/nncf/torch/function_hook/pruning/magnitude/algo.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -15,7 +15,7 @@
 import nncf
 from nncf.parameters import PruneMode
 from nncf.torch.function_hook.hook_storage import decode_hook_name
-from nncf.torch.function_hook.prune.magnitude.modules import UnstructuredPruningMask
+from nncf.torch.function_hook.pruning.magnitude.modules import UnstructuredPruningMask
 from nncf.torch.function_hook.wrapper import get_hook_storage
 from nncf.torch.function_hook.wrapper import register_post_function_hook
 from nncf.torch.model_graph_manager import get_const_data_by_name
@@ -114,7 +114,7 @@ def update_pruning_ratio(
         new_mask = (abs_data > threshold).to(dtype=torch.bool)
 
         # Set new mask
-        hook.binary_mask = new_mask
+        hook.binary_mask.copy_(new_mask)
 
     elif mode == PruneMode.UNSTRUCTURED_MAGNITUDE_GLOBAL:
         # Get threshold value for all normalized weights
@@ -135,7 +135,7 @@ def update_pruning_ratio(
         new_mask = (norm_data > threshold_val).to(dtype=torch.bool)
 
         # Set new mask
-        hook.binary_mask = new_mask
+        hook.binary_mask.copy_(new_mask)
     else:
         msg = f"Unsupported pruning mode: {mode}"
         raise nncf.InternalError(msg)
```
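
The switch from rebinding `binary_mask` to `copy_` deserves a note. An in-place copy writes through the existing tensor storage, so the buffer keeps its identity, device, and dtype, and any reference captured elsewhere observes the update; plain rebinding swaps in a new tensor and leaves old references stale. A minimal illustration, with a toy module standing in for `UnstructuredPruningMask` (only the buffer name is taken from the diff; the class body is assumed):

```python
import torch
from torch import nn


class ToyMask(nn.Module):
    def __init__(self, numel: int):
        super().__init__()
        self.register_buffer("binary_mask", torch.ones(numel, dtype=torch.bool))


new_mask = torch.tensor([True, False, True, False])

# Rebinding: the buffer now points at a different tensor object, so a
# previously captured reference no longer reflects hook.binary_mask.
hook = ToyMask(4)
captured = hook.binary_mask
hook.binary_mask = new_mask
assert captured is not hook.binary_mask

# In-place copy_: same tensor object, updated contents; the captured
# reference stays in sync with the buffer.
hook = ToyMask(4)
captured = hook.binary_mask
hook.binary_mask.copy_(new_mask)
assert captured is hook.binary_mask and captured.equal(new_mask)
```
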

src/nncf/torch/function_hook/prune/magnitude/modules.py renamed to src/nncf/torch/function_hook/pruning/magnitude/modules.py

File renamed without changes.
