Commit d1d65fc
[PT] BatchNorm adaptation (#3726)
### Changes

Add `nncf.batch_norm_adaptation` function.
Add `mag_bn` mode to the example.

### Related tickets

174483

### Tests

https://github.com/openvinotoolkit/nncf/actions/runs/19162855749/job/54776694802

Co-authored-by: Lyalyushkin Nikolay <[email protected]>
1 parent 2d7597b commit d1d65fc

File tree: 9 files changed, +292 -30 lines changed

examples/pruning/torch/resnet18/README.md

Lines changed: 6 additions & 1 deletion

@@ -31,7 +31,12 @@ python3 -m pip install ../../../../ -r requirements.txt
 It's pretty simple. The example does not require additional preparation. It will do the preparation itself, such as loading the dataset and model, etc.
 
 ```bash
+# To run Magnitude-Based pruning
 python main.py
-# Or to run Regularization-Based pruning
+
+# To run Magnitude-Based pruning with batch norm adaptation
+python main.py --mode mag_bn
+
+# To run Regularization-Based pruning
 python main.py --mode rb
 ```

examples/pruning/torch/resnet18/main.py

Lines changed: 51 additions & 22 deletions

@@ -12,6 +12,7 @@
 import os
 import warnings
 from argparse import ArgumentParser
+from argparse import RawTextHelpFormatter
 from pathlib import Path
 
 import openvino as ov
@@ -48,13 +49,18 @@
 
 
 def get_argument_parser() -> ArgumentParser:
-    parser = ArgumentParser()
+    parser = ArgumentParser(formatter_class=RawTextHelpFormatter)
     parser.add_argument(
         "--mode",
         type=str,
-        choices=["magnitude", "rb"],
-        default="magnitude",
-        help="Pruning mode to use. Choices are: magnitude, rb. Default is magnitude.",
+        choices=["mag", "mag_bn", "rb"],
+        default="mag",
+        help=(
+            "Pruning mode to use. Choices are:\n"
+            " - mag: Magnitude-based pruning with fine-tuning (default).\n"
+            " - mag_bn: Magnitude-based pruning with BatchNorm adaptation without fine-tuning.\n"
+            " - rb: Regularization-based pruning with fine-tuning.\n"
+        ),
     )
     return parser
 
@@ -82,13 +88,12 @@ def get_resnet18_model(device: torch.device) -> nn.Module:
 def train_epoch(
     train_loader: DataLoader,
     model: nn.Module,
-    criterion: nn.Module,
     rb_loss: RBLoss,
     optimizer: torch.optim.Optimizer,
     device: torch.device,
 ):
-    # Switch to train mode.
     model.train()
+    criterion = nn.CrossEntropyLoss().to(device)
 
     for images, target in track(train_loader, total=len(train_loader), description="Fine tuning:"):
         images = images.to(device)
@@ -107,7 +112,6 @@ def train_epoch(
 
 @torch.no_grad()
 def validate(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module, device: torch.device) -> float:
-    # Switch to evaluate mode.
     model.eval()
 
     correct = 0
@@ -201,14 +205,20 @@ def main() -> float:
 
     ###############################################################################
     # Step 2: Prune model
-    print(os.linesep + "[Step 2]: Prune model and specify training parameters")
+    print(os.linesep + "[Step 2] Prune model and specify training parameters")
 
-    if pruning_mode == "magnitude":
+    if pruning_mode == "mag_bn":
+        pruned_model = nncf.prune(
+            model,
+            mode=PruneMode.UNSTRUCTURED_MAGNITUDE_GLOBAL,
+            ratio=0.6,
+            examples_inputs=example_input,
+        )
+    elif pruning_mode == "mag":
         pruned_model = nncf.prune(
             model,
             mode=PruneMode.UNSTRUCTURED_MAGNITUDE_GLOBAL,
             ratio=0.7,
-            ignored_scope=nncf.IgnoredScope(),
             examples_inputs=example_input,
         )
         num_epochs = 2
@@ -217,11 +227,10 @@ def main() -> float:
             model=model, mode=PruneMode.UNSTRUCTURED_MAGNITUDE_GLOBAL, steps={0: 0.5, 1: 0.7}
         )
         optimizer = torch.optim.Adam(pruned_model.parameters(), lr=1e-5)
-    else:
+    elif pruning_mode == "rb":
         pruned_model = nncf.prune(
             model,
             mode=PruneMode.UNSTRUCTURED_REGULARIZATION_BASED,
-            ignored_scope=nncf.IgnoredScope(),
             examples_inputs=example_input,
         )
         num_epochs = 30
@@ -237,32 +246,52 @@ def main() -> float:
                 {"params": mask_params, "lr": 1e-2, "weight_decay": 0.0},
             ]
         )
-
-    criterion = nn.CrossEntropyLoss().to(device)
+    else:
+        msg = f"Unsupported pruning mode: {pruning_mode}, please choose from ['mag', 'mag_bn', 'rb']"
+        raise ValueError(msg)
 
     ###############################################################################
     # Step 3: Fine tune
     print(os.linesep + "[Step 3] Fine tune with multi step pruning ratio scheduler")
 
-    for epoch in range(num_epochs):
-        print(os.linesep + f"Train epoch: {epoch}")
-        scheduler.step()
-        train_epoch(train_loader, pruned_model, criterion, rb_loss, optimizer, device=device)
+    if pruning_mode == "mag_bn":
+        acc1_before = validate(val_loader, pruned_model, device)
+        print(f"Accuracy@1 of pruned model before BatchNorm adaptation: {acc1_before:.3f}")
+
+        def transform_fn(batch: tuple[torch.Tensor, int]) -> torch.Tensor:
+            inputs, _ = batch
+            return inputs.to(device=device)
+
+        calibration_dataset = nncf.Dataset(train_loader, transform_func=transform_fn)
+
+        pruned_model = nncf.batch_norm_adaptation(
+            pruned_model,
+            calibration_dataset=calibration_dataset,
+            num_iterations=200,
+        )
 
         acc1 = validate(val_loader, pruned_model, device)
-        print(f"Current pruning ratio: {scheduler.current_ratio:.3f}")
-        print(f"Accuracy@1 of pruned model after {epoch} epoch: {acc1:.3f}")
+        print(f"Accuracy@1 of pruned model after BatchNorm adaptation: {acc1:.3f}")
+    else:
+        for epoch in range(num_epochs):
+            print(os.linesep + f"Train epoch: {epoch}")
+            scheduler.step()
+            train_epoch(train_loader, pruned_model, rb_loss, optimizer, device=device)
+
+            acc1 = validate(val_loader, pruned_model, device)
+            print(f"Current pruning ratio: {scheduler.current_ratio:.3f}")
+            print(f"Accuracy@1 of pruned model after {epoch} epoch: {acc1:.3f}")
 
     ###############################################################################
     # Step 4: Print per tensor pruning statistics
-    print(os.linesep + "[Step 4]: Pruning statistics")
+    print(os.linesep + "[Step 4] Pruning statistics")
 
     pruning_stat = nncf.pruning_statistic(pruned_model)
     print(pruning_stat)
 
    ###############################################################################
     # Step 5: Export models
-    print(os.linesep + "[Step 5]: Export models")
+    print(os.linesep + "[Step 5] Export models")
     ir_path = ROOT / f"{BASE_MODEL_NAME}_pruned.xml"
     ov_model = ov.convert_model(pruned_model.cpu(), example_input=example_input.cpu(), input=tuple(example_input.shape))
     ov.save_model(ov_model, ir_path, compress_to_fp16=False)
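
The `mag_bn` path works because unstructured pruning shifts the activation statistics each BatchNorm layer sees, so the stored running mean and variance go stale; forwarding a few hundred calibration batches with BatchNorm in train mode refreshes them without any gradient updates. A minimal sketch of the underlying PyTorch mechanism (standalone illustration, not code from this commit):

```python
import torch
from torch import nn

bn = nn.BatchNorm2d(4)  # running_mean starts at 0, running_var at 1

# Activations whose statistics no longer match the stored ones.
x = torch.randn(16, 4, 8, 8) * 3 + 5

# In train mode, a forward pass updates the running statistics with an
# exponential moving average: new = (1 - momentum) * old + momentum * batch_stat.
bn.train()
with torch.no_grad():
    bn(x)

print(bn.running_mean)  # nudged from 0 toward the batch mean (~5)
print(bn.running_var)   # nudged from 1 toward the batch variance (~9)
```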

src/nncf/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -44,6 +44,7 @@
 from nncf.parameters import SensitivityMetric as SensitivityMetric
 from nncf.parameters import StripFormat as StripFormat
 from nncf.parameters import TargetDevice as TargetDevice
+from nncf.pruning.prune_model import batch_norm_adaptation as batch_norm_adaptation
 from nncf.pruning.prune_model import prune as prune
 from nncf.pruning.prune_model import pruning_statistic as pruning_statistic
 from nncf.quantization import QuantizationPreset as QuantizationPreset

src/nncf/pruning/prune_model.py

Lines changed: 24 additions & 0 deletions

@@ -17,6 +17,7 @@
 from nncf.common.utils.backend import BackendType
 from nncf.common.utils.backend import get_backend
 from nncf.common.utils.helpers import create_table
+from nncf.data.dataset import Dataset
 from nncf.parameters import PruneMode
 from nncf.scopes import IgnoredScope
 
@@ -51,6 +52,29 @@ def prune(
     return model
 
 
+def batch_norm_adaptation(
+    model: TModel, calibration_dataset: Dataset, *, num_iterations: Optional[int] = None
+) -> TModel:
+    """
+    Adapt the batch normalization layers of the given model using the provided dataset.
+    This function runs a specified number of iterations through the model
+    to update the running statistics of the batch normalization layers.
+
+    :param model: The model to adapt.
+    :param calibration_dataset: The dataset to use for the adaptation.
+    :param num_iterations: The number of iterations to use for adaptation.
+        If set to None, the adaptation will run for the entire dataset.
+    """
+    backend = get_backend(model)
+    if backend == BackendType.TORCH:
+        from nncf.torch.function_hook.pruning.batch_norm_adaptation import batch_norm_adaptation
+
+        return batch_norm_adaptation(model, calibration_dataset=calibration_dataset, num_iterations=num_iterations)
+
+    msg = f"Batch norm adaptation is not supported for the {backend} backend."
+    raise nncf.InternalError(msg)
+
+
 @dataclass
 class TensorPruningStatistic:
     """
src/nncf/torch/function_hook/pruning/batch_norm_adaptation.py

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#      http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import contextmanager
+from typing import Generator, Optional, TypeVar
+
+import torch
+from torch import nn
+
+from nncf import Dataset
+from nncf.common.logging.track_progress import track
+
+TModel = TypeVar("TModel", bound=nn.Module)
+
+
+@torch.no_grad()
+def batch_norm_adaptation(
+    model: TModel, calibration_dataset: Dataset, *, num_iterations: Optional[int] = None
+) -> TModel:
+    """
+    Adapt the batch normalization layers of the given model using the provided dataset.
+
+    This function runs a specified number of iterations (batches) through the model
+    to update the running statistics of the batch normalization layers.
+
+    :param model: The model to adapt.
+    :param calibration_dataset: The dataset to use for the adaptation.
+    :param num_iterations: The number of iterations (batches) to use for adaptation.
+        If set to None, the adaptation will run for the entire dataset.
+    """
+    with set_batchnorm_train_only(model):
+        total = calibration_dataset.get_length()
+        if num_iterations is not None:
+            total = min(num_iterations, total) if total is not None else num_iterations
+
+        for idx, input_data in track(
+            enumerate(calibration_dataset.get_inference_data()),
+            total=total,
+            description="Batch norm adaptation",
+        ):
+            if num_iterations is not None and idx >= num_iterations:
+                break
+
+            if isinstance(input_data, dict):
+                model(**input_data)
+            elif isinstance(input_data, tuple):
+                model(*input_data)
+            else:
+                model(input_data)
+
+    return model
+
+
+@contextmanager
+def set_batchnorm_train_only(model: nn.Module) -> Generator[None, None, None]:
+    """
+    Context manager that sets only BatchNorm modules to train mode,
+    while keeping all other modules in eval mode.
+    Restores the original training states afterward.
+
+    :param model: The model.
+    """
+    # Store the original training states
+    original_states = {}
+    for name, module in model.named_modules():
+        original_states[name] = module.training
+
+    try:
+        # Set all modules to eval, then only BN to train
+        model.eval()
+        for module in model.modules():
+            if isinstance(module, nn.modules.batchnorm._BatchNorm):
+                module.train()
+        yield
+    finally:
+        # Restore original training states
+        for name, module in model.named_modules():
+            module.train(original_states[name])
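
To make the context manager's contract concrete, here is a small self-check; a sketch assuming the module path inferred above, with an illustrative toy model:

```python
import torch
from torch import nn

from nncf.torch.function_hook.pruning.batch_norm_adaptation import set_batchnorm_train_only

# Toy model mixing BatchNorm with other mode-sensitive modules.
model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.Dropout(0.5))
model.train()  # arbitrary starting state
bn = model[1]
before = bn.running_mean.clone()

with set_batchnorm_train_only(model):
    # Only BatchNorm is in train mode; Conv and Dropout stay in eval mode.
    assert bn.training and not model[0].training and not model[2].training
    with torch.no_grad():
        model(torch.randn(8, 3, 16, 16))

assert not torch.equal(bn.running_mean, before)  # running stats were updated
assert model.training  # original train/eval states restored
```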

src/nncf/torch/function_hook/pruning/magnitude/algo.py

Lines changed: 6 additions & 2 deletions

@@ -9,6 +9,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import TypeVar
+
 import torch
 from torch import nn
 
@@ -20,13 +22,15 @@
 from nncf.torch.function_hook.wrapper import register_post_function_hook
 from nncf.torch.model_graph_manager import get_const_data_by_name
 
+TModel = TypeVar("TModel", bound=nn.Module)
+
 
 def apply_magnitude_pruning(
-    model: nn.Module,
+    model: TModel,
     parameters: list[str],
     mode: PruneMode,
     ratio: float,
-) -> nn.Module:
+) -> TModel:
     """
     Prunes the specified parameters of the given model using unstructured pruning.
 
src/nncf/torch/function_hook/pruning/prune_model.py

Lines changed: 5 additions & 3 deletions

@@ -8,7 +8,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
+from typing import Any, Optional, TypeVar
 
 from torch import nn
 
@@ -25,6 +25,8 @@
 from nncf.torch.function_hook.wrapper import wrap_model
 from nncf.torch.model_graph_manager import get_const_node
 
+TModel = TypeVar("TModel", bound=nn.Module)
+
 OPERATORS_WITH_WEIGHTS_METATYPES = [
     om.PTConv1dMetatype,
     om.PTConv2dMetatype,
@@ -43,12 +45,12 @@
 
 
 def prune(
-    model: nn.Module,
+    model: TModel,
     mode: PruneMode,
     ratio: Optional[float] = None,
     ignored_scope: Optional[IgnoredScope] = None,
     examples_inputs: Optional[Any] = None,
-) -> nn.Module:
+) -> TModel:
     if examples_inputs is None:
         msg = "`sparsity` function requires `examples_inputs` argument to be specified for Torch backend"
         raise nncf.InternalError(msg)

src/nncf/torch/function_hook/pruning/rb/algo.py

Lines changed: 6 additions & 2 deletions

@@ -9,6 +9,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import TypeVar
+
 from torch import nn
 
 import nncf
@@ -18,11 +20,13 @@
 from nncf.torch.function_hook.wrapper import register_post_function_hook
 from nncf.torch.model_graph_manager import get_const_data_by_name
 
+TModel = TypeVar("TModel", bound=nn.Module)
+
 
 def apply_regularization_based_pruning(
-    model: nn.Module,
+    model: TModel,
     parameters: list[str],
-) -> nn.Module:
+) -> TModel:
     """
     :param model: The neural network model to be pruned.
     :param parameters: A list of parameter names to be pruned.
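
The recurring `TModel = TypeVar("TModel", bound=nn.Module)` edits in the last three files are a typing refinement: annotating these functions as `(model: TModel, ...) -> TModel` tells static type checkers that the returned model keeps the caller's concrete class instead of being widened to `nn.Module`. A minimal illustration (the subclass and function here are hypothetical):

```python
from typing import TypeVar

from torch import nn

TModel = TypeVar("TModel", bound=nn.Module)


class MyNet(nn.Module):  # hypothetical user model
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)


def transform(model: TModel) -> TModel:
    # Stands in for prune() / apply_*_pruning(): returns the model it received.
    return model


model: MyNet = transform(MyNet())  # type-checks: the concrete type flows through
# With a plain `-> nn.Module` annotation, assigning to `model: MyNet` would be
# rejected by a type checker even though the runtime object is unchanged.
```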
