Skip to content

Commit f832775

Browse files
[PT] Pruning statistic (#3717)
### Changes Add `nncf.pruning_statistic` to collect and display information about pruning parameters. ```python pruning_stat = nncf.pruning_statistic(pruned_model) print(pruning_stat) ``` ``` ┍━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━┑ │ Parameter's name │ Shape │ Pruning ratio │ ┝━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━┥ │ conv.weight │ (3, 3, 3, 3) │ 0.506 │ ├────────────────────┼──────────────┼─────────────────┤ │ │ │ │ ├────────────────────┼──────────────┼─────────────────┤ │ Masked parameters │ │ 0.506 │ ├────────────────────┼──────────────┼─────────────────┤ │ All parameters │ │ 0.488 │ ┕━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━┙ ``` ### Related tickets 174484 ### Tests https://github.com/openvinotoolkit/nncf/actions/runs/19045571789/job/54392808345
1 parent 1ea7e45 commit f832775

File tree

6 files changed

+185
-1
lines changed

6 files changed

+185
-1
lines changed

examples/pruning/torch/resnet18/main.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,15 @@ def main() -> float:
254254
print(f"Accuracy@1 of pruned model after {epoch} epoch: {acc1:.3f}")
255255

256256
###############################################################################
257-
# Step 4: Export models
257+
# Step 4: Print per tensor pruning statistics
258+
print(os.linesep + "[Step 4]: Pruning statistics")
259+
260+
pruning_stat = nncf.pruning_statistic(pruned_model)
261+
print(pruning_stat)
262+
263+
###############################################################################
264+
# Step 5: Export models
265+
print(os.linesep + "[Step 5]: Export models")
258266
ir_path = ROOT / f"{BASE_MODEL_NAME}_pruned.xml"
259267
ov_model = ov.convert_model(pruned_model.cpu(), example_input=example_input.cpu(), input=tuple(example_input.shape))
260268
ov.save_model(ov_model, ir_path, compress_to_fp16=False)

src/nncf/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from nncf.parameters import StripFormat as StripFormat
4646
from nncf.parameters import TargetDevice as TargetDevice
4747
from nncf.pruning.prune_model import prune as prune
48+
from nncf.pruning.prune_model import pruning_statistic as pruning_statistic
4849
from nncf.quantization import QuantizationPreset as QuantizationPreset
4950
from nncf.quantization import compress_weights as compress_weights
5051
from nncf.quantization import quantize as quantize

src/nncf/pruning/prune_model.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
from dataclasses import dataclass
1213
from typing import Any, Optional
1314

1415
import nncf
1516
from nncf.api.compression import TModel
1617
from nncf.common.utils.backend import BackendType
1718
from nncf.common.utils.backend import get_backend
19+
from nncf.common.utils.helpers import create_table
1820
from nncf.parameters import PruneMode
1921
from nncf.scopes import IgnoredScope
2022

@@ -47,3 +49,61 @@ def prune(
4749
msg = f"Pruning is not supported for the {backend} backend."
4850
raise nncf.InternalError(msg)
4951
return model
52+
53+
54+
@dataclass
55+
class TensorPruningStatistic:
56+
"""
57+
Statistics about pruning for a single tensor.
58+
59+
:param tensor_name: Name of the tensor.
60+
:param shape: Shape of the tensor.
61+
:param pruned_ratio: Ratio of pruned elements in the tensor.
62+
"""
63+
64+
tensor_name: str
65+
shape: tuple[int, ...]
66+
pruned_ratio: float
67+
68+
69+
@dataclass
70+
class ModelPruningStatistic:
71+
"""
72+
Aggregated pruning statistics for a model.
73+
74+
:param pruning_ratio: Overall pruning ratio for pruned parameters in the model.
75+
:param global_pruning_ratio: Overall pruning ratio for all parameters in the model.
76+
:param pruned_tensors: List of pruning statistics for each tensor.
77+
"""
78+
79+
pruning_ratio: float
80+
global_pruning_ratio: float
81+
pruned_tensors: list[TensorPruningStatistic]
82+
83+
def __str__(self) -> str:
84+
total = [
85+
[None, None, None, None],
86+
["Prunable parameters", None, self.pruning_ratio],
87+
["All parameters", None, self.global_pruning_ratio],
88+
]
89+
90+
sorted_stat_per_tensor = sorted(self.pruned_tensors, key=lambda s: s.tensor_name)
91+
rows_per_tensor = [[s.tensor_name, s.shape, s.pruned_ratio] for s in sorted_stat_per_tensor]
92+
text = create_table(header=["Parameter's name", "Shape", "Pruning ratio"], rows=rows_per_tensor + total)
93+
return text
94+
95+
96+
def pruning_statistic(model: TModel) -> ModelPruningStatistic:
97+
"""
98+
Collects and returns pruning statistics for the given model.
99+
100+
:param model: The pruned model.
101+
:return: A pruning statistic.
102+
"""
103+
backend = get_backend(model)
104+
if backend == BackendType.TORCH:
105+
from nncf.torch.function_hook.pruning.statistics import pruning_statistic
106+
107+
return pruning_statistic(model)
108+
msg = f"Pruning statistics collection is not supported for the {backend} backend."
109+
raise nncf.InternalError(msg)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
13+
import torch
14+
from torch import nn
15+
16+
from nncf.pruning.prune_model import ModelPruningStatistic
17+
from nncf.pruning.prune_model import TensorPruningStatistic
18+
from nncf.torch.function_hook.hook_storage import decode_hook_name
19+
from nncf.torch.function_hook.pruning.magnitude.modules import UnstructuredPruningMask
20+
from nncf.torch.function_hook.pruning.rb.modules import RBPruningMask
21+
from nncf.torch.function_hook.pruning.rb.modules import binary_mask
22+
from nncf.torch.function_hook.wrapper import get_hook_storage
23+
24+
25+
@torch.no_grad()
26+
def pruning_statistic(model: nn.Module) -> ModelPruningStatistic:
27+
"""
28+
Collects and returns pruning statistics for the given model.
29+
30+
:param model: The pruned model.
31+
:return: Pruning statistics.
32+
"""
33+
total_params = sum(p.numel() for p in model.parameters())
34+
num_elements = 0
35+
pruned_elements = 0
36+
stat_per_tensors: list[TensorPruningStatistic] = []
37+
38+
hook_storage = get_hook_storage(model)
39+
for hook_name, hook_module in hook_storage.named_hooks():
40+
if isinstance(hook_module, UnstructuredPruningMask):
41+
mask = hook_module.binary_mask
42+
elif isinstance(hook_module, RBPruningMask):
43+
mask = binary_mask(hook_module.mask)
44+
# Exclude RBPruningMask’s internal mask parameters from the total parameter count
45+
total_params -= mask.numel()
46+
else:
47+
continue
48+
49+
pruned_el = int(torch.sum(mask == 0).item())
50+
num_el = mask.numel()
51+
shape = tuple(mask.shape)
52+
pruned_ratio = pruned_el / num_el if num_el != 0 else 0.0
53+
54+
_, tensor_name, _ = decode_hook_name(hook_name)
55+
56+
num_elements += num_el
57+
pruned_elements += pruned_el
58+
59+
stat_per_tensors.append(TensorPruningStatistic(tensor_name, shape, pruned_ratio))
60+
61+
masked_ratio = pruned_elements / num_elements if num_elements != 0 else 0.0
62+
global_ratio = pruned_elements / total_params if total_params != 0 else 0.0
63+
64+
return ModelPruningStatistic(
65+
pruning_ratio=masked_ratio,
66+
global_pruning_ratio=global_ratio,
67+
pruned_tensors=stat_per_tensors,
68+
)

tests/torch2/function_hook/pruning/magnitude/test_algo.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,23 @@ def test_save_load(tmpdir: Path):
149149
assert isinstance(d["post_hooks.conv:weight__0.0"], UnstructuredPruningMask)
150150

151151
assert torch.allclose(orig_output, loaded_output)
152+
153+
154+
def test_statistic():
155+
model = ConvModel()
156+
example_inputs = ConvModel.get_example_inputs()
157+
158+
pruned_model = nncf.prune(
159+
model, mode=PruneMode.UNSTRUCTURED_MAGNITUDE_LOCAL, ratio=0.5, examples_inputs=example_inputs
160+
)
161+
stat = nncf.pruning_statistic(pruned_model)
162+
163+
assert pytest.approx(stat.pruned_tensors[0].pruned_ratio, abs=1e-1) == 0.5
164+
assert stat.pruned_tensors[0].tensor_name == "conv.weight"
165+
assert stat.pruned_tensors[0].shape == (3, 3, 3, 3)
166+
assert pytest.approx(stat.pruning_ratio, abs=1e-2) == 0.5
167+
assert pytest.approx(stat.global_pruning_ratio, abs=1e-2) == 0.48
168+
169+
txt = str(stat)
170+
assert "conv.weight" in txt
171+
assert "All parameters" in txt

tests/torch2/function_hook/pruning/rb/test_algo.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,30 @@ def test_save_load(tmpdir: Path):
114114
assert isinstance(d["post_hooks.conv:weight__0.0"], RBPruningMask)
115115

116116
assert torch.allclose(orig_output, loaded_output)
117+
118+
119+
def test_statistic():
120+
model = ConvModel()
121+
example_inputs = ConvModel.get_example_inputs()
122+
123+
pruned_model = nncf.prune(
124+
model, mode=PruneMode.UNSTRUCTURED_REGULARIZATION_BASED, ratio=0.5, examples_inputs=example_inputs
125+
)
126+
127+
# Set mask
128+
with torch.no_grad():
129+
hook_storage = get_hook_storage(pruned_model)
130+
pruning_module = hook_storage.post_hooks["conv:weight__0"]["0"]
131+
pruning_module.mask[0] *= -1
132+
133+
stat = nncf.pruning_statistic(pruned_model)
134+
135+
assert pytest.approx(stat.pruned_tensors[0].pruned_ratio, abs=1e-1) == 0.3
136+
assert stat.pruned_tensors[0].tensor_name == "conv.weight"
137+
assert stat.pruned_tensors[0].shape == (3, 3, 3, 3)
138+
assert pytest.approx(stat.pruning_ratio, abs=1e-2) == 0.33
139+
assert pytest.approx(stat.global_pruning_ratio, abs=1e-2) == 0.32
140+
141+
txt = str(stat)
142+
assert "conv.weight" in txt
143+
assert "All parameters" in txt

0 commit comments

Comments
 (0)