
Commit 1ea7e45

[PT] Strip pruned model (#3716)
### Changes

Add a strip function for pruned models. Simplify strip to avoid building a graph for the `IN_PLACE` and `DQ` formats.

### Related tickets

176026

### Tests

https://github.com/openvinotoolkit/nncf/actions/runs/19017695102
1 parent f09af53 commit 1ea7e45
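
Below is a minimal usage sketch of what this change enables; `model` and `example_input` are placeholder names, not part of the commit:

```python
import nncf
from nncf.parameters import StripFormat

# DQ and IN_PLACE strips no longer build an NNCF graph, so example_input
# can be omitted (model is a placeholder for a compressed nn.Module):
stripped = nncf.strip(model, do_copy=False, strip_format=StripFormat.DQ)

# NATIVE strip still traces the model, so it still requires example_input:
stripped = nncf.strip(model, strip_format=StripFormat.NATIVE, example_input=example_input)
```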

File tree

12 files changed: +196 −67 lines

docs/usage/training_time_compression/quantization_aware_training_lora/Usage.md

Lines changed: 1 addition & 1 deletion
@@ -98,5 +98,5 @@ To convert a PyTorch model to an INT4 OpenVINO model, transform the `FQ_LORA` or
 
 ```python
 # Convert to OpenVINO format after training is complete
-compressed_model = nncf.strip(model, strip_format=StripFormat.DQ, example_input=example_input)
+compressed_model = nncf.strip(model, strip_format=StripFormat.DQ)
 ```

examples/llm_compression/torch/distillation_qat_with_lora/main.py

Lines changed: 1 addition & 2 deletions
@@ -228,9 +228,8 @@ def export_to_openvino(pretrained: str, ckpt_file: Path, ir_dir: Path) -> OVMode
     :return: A wrapper of OpenVINO model ready for evaluation.
     """
     model_to_eval = AutoModelForCausalLM.from_pretrained(pretrained, torch_dtype=torch.float32, device_map="cpu")
-    example_input = model_to_eval.dummy_inputs
     model_to_eval = load_checkpoint(model_to_eval, ckpt_file)
-    model_to_eval = nncf.strip(model_to_eval, do_copy=False, strip_format=StripFormat.DQ, example_input=example_input)
+    model_to_eval = nncf.strip(model_to_eval, do_copy=False, strip_format=StripFormat.DQ)
     export_from_model(model_to_eval, ir_dir, device="cpu")
     return OVModelForCausalLM.from_pretrained(
         model_id=ir_dir,

examples/llm_compression/torch/downstream_qat_with_nls/main.py

Lines changed: 1 addition & 2 deletions
@@ -446,14 +446,13 @@ def export_to_openvino(
     :return: A wrapper of OpenVINO model ready for evaluation.
     """
     model_to_eval = AutoModelForCausalLM.from_pretrained(pretrained, torch_dtype=torch.float32, device_map="cpu")
-    example_input = model_to_eval.dummy_inputs
     model_to_eval = load_checkpoint(model_to_eval, ckpt_file)
     if specific_rank_config is not None:
         configure_lora_adapters(
             get_layer_id_vs_lora_quantizers_map(model_to_eval),
             specific_rank_config=specific_rank_config,
         )
-    model_to_eval = nncf.strip(model_to_eval, do_copy=False, strip_format=StripFormat.DQ, example_input=example_input)
+    model_to_eval = nncf.strip(model_to_eval, do_copy=False, strip_format=StripFormat.DQ)
     export_from_model(model_to_eval, ir_dir, device="cpu")
     return OVModelForCausalLM.from_pretrained(
         model_id=ir_dir,

src/nncf/parameters.py

Lines changed: 1 addition & 2 deletions
@@ -150,8 +150,7 @@ class StripFormat(StrEnum):
     :param DQ: Replaces FakeQuantize operations with a dequantization subgraph and stores compressed weights
         in low-bit precision using fake quantize parameters. This is the default format for deploying models
         with compressed weights.
-    :param IN_PLACE: Directly applies fake quantizers to the weights, replacing the original weights with their
-        fake quantized versions.
+    :param IN_PLACE: Directly applies NNCF operations to the weights, replacing the original weights.
     """
 
     NATIVE = "native"

src/nncf/torch/function_hook/hook_storage.py

Lines changed: 2 additions & 0 deletions
@@ -229,6 +229,8 @@ def delete_hook(self, hook_name: str) -> None:
             raise ValueError(msg)
 
         del storage_dict[hook_key][hook_id]
+        if not storage_dict[hook_key]:
+            del storage_dict[hook_key]
 
 
 def decode_hook_name(hook_name: str) -> tuple[str, str, int]:
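
The two added lines keep the storage from accumulating empty inner dicts: once the last hook registered under a key is deleted, the key itself is dropped. A plain-dict sketch of the same logic (keys are illustrative, not the real storage class):

```python
# Plain-dict illustration of the cleanup added to delete_hook:
storage_dict = {"conv:weight__0": {"0": "hook"}}

del storage_dict["conv:weight__0"]["0"]  # delete the hook itself
if not storage_dict["conv:weight__0"]:   # inner dict is now empty...
    del storage_dict["conv:weight__0"]   # ...so drop the key too

assert "conv:weight__0" not in storage_dict  # no empty entry left behind
```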
src/nncf/torch/function_hook/pruning/strip.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TypeVar
+
+import torch
+from torch import nn
+
+import nncf
+from nncf.torch.function_hook.hook_storage import decode_hook_name
+from nncf.torch.function_hook.pruning.magnitude.modules import UnstructuredPruningMask
+from nncf.torch.function_hook.pruning.rb.modules import RBPruningMask
+from nncf.torch.function_hook.wrapper import get_hook_storage
+from nncf.torch.model_graph_manager import get_module_by_name
+from nncf.torch.model_graph_manager import split_const_name
+
+TModel = TypeVar("TModel", bound=nn.Module)
+
+
+@torch.no_grad()
+def apply_pruning_in_place(model: TModel) -> TModel:
+    """
+    Applies pruning masks in-place to the weights:
+    (weights + pruning mask) -> (pruned weights)
+
+    :param model: Compressed model
+    :return: The modified NNCF network.
+    """
+    hook_storage = get_hook_storage(model)
+    hooks_to_delete = []
+    for hook_name, hook_module in hook_storage.named_hooks():
+        if not isinstance(hook_module, (RBPruningMask, UnstructuredPruningMask)):
+            continue
+
+        hook_module.eval()
+        hook_type, op_name, port_id = decode_hook_name(hook_name)
+        if hook_type != "post_hooks" or port_id != 0:
+            msg = f"Unexpected place of SparsityBinaryMask: {hook_type=}, {op_name=}, {port_id=}"
+            raise nncf.InternalError(msg)
+
+        module_name, weight_attr_name = split_const_name(op_name)
+        module = get_module_by_name(module_name, model)
+        weight_param = getattr(module, weight_attr_name)
+
+        weight_param.requires_grad = False
+        weight_param.data = hook_module(weight_param)
+
+        hooks_to_delete.append(hook_name)
+
+    for hook_name in hooks_to_delete:
+        hook_storage.delete_hook(hook_name)
+    return model
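
A hedged usage sketch of the new helper; `pruned_model` is a placeholder for a model returned by `nncf.prune`. This is the step that `StripFormat.IN_PLACE` now runs before folding quantizers in place:

```python
from nncf.torch.function_hook.pruning.strip import apply_pruning_in_place

# Folds each RBPruningMask / UnstructuredPruningMask into its weight tensor and
# deletes the mask hooks; normally reached via nncf.strip(..., StripFormat.IN_PLACE).
stripped = apply_pruning_in_place(pruned_model)  # pruned_model: placeholder name
```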

src/nncf/torch/function_hook/strip.py

Lines changed: 38 additions & 48 deletions
@@ -20,6 +20,7 @@
 from nncf.parameters import StripFormat
 from nncf.torch.function_hook.hook_storage import decode_hook_name
 from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph
+from nncf.torch.function_hook.pruning.strip import apply_pruning_in_place
 from nncf.torch.function_hook.wrapper import get_hook_storage
 from nncf.torch.model_graph_manager import get_const_data
 from nncf.torch.model_graph_manager import get_const_node
@@ -36,7 +37,7 @@
 TModel = TypeVar("TModel", bound=nn.Module)
 
 
-def strip_quantized_model(model: TModel, example_input: Any, strip_format: StripFormat = StripFormat.NATIVE) -> TModel:
+def strip_model(model: TModel, example_input: Any = None, strip_format: StripFormat = StripFormat.NATIVE) -> TModel:
     """
     Removes auxiliary layers and operations added during the quantization process,
     resulting in a clean quantized model ready for deployment. The functionality of the model object is still preserved
@@ -47,14 +48,17 @@ def strip_quantized_model(model: TModel, example_input: Any, strip_format: Strip
     :param strip_format: Describes the format in which model is saved after strip.
     :return: The modified NNCF network.
     """
-    graph = build_nncf_graph(model, example_input)
-
     if strip_format == StripFormat.NATIVE:
+        if example_input is None:
+            msg = "The example_input parameter is required to strip the model."
+            raise nncf.InternalError(msg)
+        graph = build_nncf_graph(model, example_input)
         model = replace_quantizer_to_torch_native_module(model, graph)
     elif strip_format == StripFormat.DQ:
-        model = replace_quantizer_to_compressed_weight_with_decompressor(model, graph)
+        model = replace_quantizer_to_compressed_weight_with_decompressor(model)
     elif strip_format == StripFormat.IN_PLACE:
-        model = apply_compression_in_place(model, graph)
+        model = apply_pruning_in_place(model)
+        model = apply_compression_in_place(model)
     else:
         msg = f"Unsupported strip format: {strip_format}"
         raise nncf.ParameterNotSupportedError(msg)
@@ -105,57 +109,48 @@ def replace_quantizer_to_torch_native_module(model: TModel, graph: NNCFGraph) ->
     return model
 
 
-def replace_quantizer_to_compressed_weight_with_decompressor(model: TModel, graph: NNCFGraph) -> TModel:
+def replace_quantizer_to_compressed_weight_with_decompressor(model: TModel) -> TModel:
     """
     Performs transformation from fake quantize format (FQ) to dequantization one (DQ):
     (weights + FQ) -> (compressed_weights + DQ)
 
     :param model: Compressed model
-    :param graph: The model graph.
     :return: The modified NNCF network.
     """
     hook_storage = get_hook_storage(model)
 
-    for name, module in hook_storage.named_hooks():
-        if not isinstance(module, (SymmetricQuantizer, AsymmetricQuantizer)):
+    for hook_name, hook_module in hook_storage.named_hooks():
+        if not isinstance(hook_module, (SymmetricQuantizer, AsymmetricQuantizer)):
             continue
         msg = ""
-        if module._qspec.half_range or module._qspec.narrow_range:
+        if hook_module._qspec.half_range or hook_module._qspec.narrow_range:
            msg += "Unexpected parameters of quantizers on strip: half_range and narrow_range should be False.\n"
-        if module.num_bits not in [4, 8]:
-            msg += f"Unsupported number of bits {module.num_bits} for the quantizer {module}.\n"
+        if hook_module.num_bits not in [4, 8]:
+            msg += f"Unsupported number of bits {hook_module.num_bits} for the quantizer {hook_module}.\n"
         if msg:
             raise nncf.ValidationError(msg)
 
-        _, op_name, _ = decode_hook_name(name)
-        weight_node = graph.get_node_by_name(op_name)
-
-        if weight_node is None:
-            msg = "FQ is not assigned to weight. Strip to DQ format is not supported for FQ on activation."
-            raise nncf.UnsupportedModelError(msg)
-
-        if not isinstance(weight_node.layer_attributes, ConstantLayerAttributes):
-            msg = f"Unexpected layer attributes type {type(weight_node.layer_attributes)}"
-            raise nncf.InternalError(msg)
-
-        weight = get_const_data(weight_node, model)
+        _, op_name, _ = decode_hook_name(hook_name)
 
-        convert_fn = asym_fq_to_decompressor if isinstance(module, AsymmetricQuantizer) else sym_fq_to_decompressor
-        decompressor, q_weight = convert_fn(module, weight)  # type: ignore[operator]
-        packed_tensor = decompressor.pack_weight(q_weight)
-
-        module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name)
+        module_name, weight_attr_name = split_const_name(op_name)
         module = get_module_by_name(module_name, model)
         weight_param = getattr(module, weight_attr_name)
 
+        with torch.no_grad():
+            if isinstance(hook_module, AsymmetricQuantizer):
+                decompressor, q_weight = asym_fq_to_decompressor(hook_module, weight_param)
+            else:
+                decompressor, q_weight = sym_fq_to_decompressor(hook_module, weight_param)  # type: ignore[assignment]
+            packed_tensor = decompressor.pack_weight(q_weight)
+
         weight_param.requires_grad = False
         weight_param.data = packed_tensor
 
-        hook_storage.set_submodule(name, decompressor)
+        hook_storage.set_submodule(hook_name, decompressor)
     return model
 
 
-def apply_compression_in_place(model: TModel, graph: NNCFGraph) -> TModel:
+def apply_compression_in_place(model: TModel) -> TModel:
     """
     Applies fake quantizers in-place to the weights:
     (weights + FQ) -> (fake quantized weights)
@@ -167,31 +162,26 @@ def apply_compression_in_place(model: TModel, graph: NNCFGraph) -> TModel:
     hook_storage = get_hook_storage(model)
 
     hooks_to_delete = []
-    for name, hook in hook_storage.named_hooks():
-        if not isinstance(hook, (SymmetricQuantizer, AsymmetricQuantizer, BaseWeightsDecompressor)):
+    for hook_name, hook_module in hook_storage.named_hooks():
+        if not isinstance(hook_module, (SymmetricQuantizer, AsymmetricQuantizer, BaseWeightsDecompressor)):
            continue
-        _, op_name, _ = decode_hook_name(name)
-        weight_node = graph.get_node_by_name(op_name)
+        hook_module.eval()
 
-        if weight_node is None:
-            msg = "FQ is not assigned to weight. In-place strip is not supported for FQ on activation."
-            raise nncf.UnsupportedModelError(msg)
-
-        if not isinstance(weight_node.layer_attributes, ConstantLayerAttributes):
-            msg = f"Unexpected layer attributes type {type(weight_node.layer_attributes)}"
-            raise nncf.InternalError(msg)
-
-        weight = get_const_data(weight_node, model)
-        fq_weight = hook(weight) if isinstance(hook, BaseWeightsDecompressor) else hook.quantize(weight)
-
-        module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name)
+        _, op_name, _ = decode_hook_name(hook_name)
+        module_name, weight_attr_name = split_const_name(op_name)
         module = get_module_by_name(module_name, model)
         weight_param = getattr(module, weight_attr_name)
 
+        with torch.no_grad():
+            if isinstance(hook_module, (SymmetricQuantizer, AsymmetricQuantizer)):
+                fq_weight = hook_module.quantize(weight_param)
+            else:
+                fq_weight = hook_module(weight_param)
+
         weight_param.requires_grad = False
         weight_param.data = fq_weight
 
-        hooks_to_delete.append(name)
+        hooks_to_delete.append(hook_name)
 
     for hook_name in hooks_to_delete:
         hook_storage.delete_hook(hook_name)
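
Since the graph build moved inside the `NATIVE` branch, the `example_input` check now lives in `strip_model`. A small sketch of the failure mode; `quantized_model` is a placeholder:

```python
import nncf
from nncf.parameters import StripFormat

try:
    nncf.strip(quantized_model, strip_format=StripFormat.NATIVE)  # no example_input
except nncf.InternalError as err:
    print(err)  # "The example_input parameter is required to strip the model."
```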

src/nncf/torch/strip.py

Lines changed: 4 additions & 8 deletions
@@ -11,11 +11,10 @@
 
 
 from copy import deepcopy
-from typing import Any, Optional, TypeVar
+from typing import Any, TypeVar
 
 from torch import nn
 
-import nncf
 from nncf.common.check_features import is_torch_tracing_by_patching
 from nncf.parameters import StripFormat
 
@@ -26,7 +25,7 @@ def strip(
     model: TModel,
     do_copy: bool = True,
     strip_format: StripFormat = StripFormat.NATIVE,
-    example_input: Optional[Any] = None,
+    example_input: Any = None,
 ) -> TModel:
     """
     Removes auxiliary layers and operations added during the compression process, resulting in a clean
@@ -41,10 +40,7 @@
     if is_torch_tracing_by_patching():
         return model.nncf.strip(do_copy, strip_format)
 
-    from nncf.torch.function_hook.strip import strip_quantized_model
+    from nncf.torch.function_hook.strip import strip_model
 
-    if example_input is None:
-        msg = "Required example_input for strip model."
-        raise nncf.InternalError(msg)
     model = deepcopy(model) if do_copy else model
-    return strip_quantized_model(model, example_input, strip_format)
+    return strip_model(model, example_input, strip_format)
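
One detail preserved by this refactor: `do_copy=True` (the default) deep-copies the model before stripping, so the original keeps its hooks, while `do_copy=False` strips in place. A sketch with `model` as a placeholder:

```python
# do_copy=True (default): returns a stripped deepcopy; `model` keeps its hooks.
stripped_copy = nncf.strip(model, strip_format=StripFormat.DQ)

# do_copy=False: `model` itself is stripped and returned.
stripped = nncf.strip(model, do_copy=False, strip_format=StripFormat.DQ)
```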

tests/torch2/function_hook/pruning/helpers.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ def get_example_inputs():
     def __init__(self) -> None:
         super().__init__()
         self.conv = nn.Conv2d(3, 3, 3)
+        self.conv.weight.data = torch.arange(1, 82, dtype=torch.float32).view(3, 3, 3, 3)
 
     def forward(self, x: torch.Tensor):
         x = self.conv(x)
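
The deterministic weights make the pruning result in the new test below reproducible: with magnitudes 1..81 and `ratio=0.5`, local magnitude pruning zeroes roughly the smaller half of the 81 values. A back-of-envelope check (the median threshold is an assumption; NNCF's exact tie-breaking may differ):

```python
import torch

weights = torch.arange(1, 82, dtype=torch.float32)  # 81 values, as in ConvModel
threshold = weights.abs().quantile(0.5)              # median magnitude: 41.0
print(int((weights.abs() > threshold).sum()))        # 40 values survive, matching the test
```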
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+import nncf
+from nncf import PruneMode
+from nncf.torch.function_hook.pruning.magnitude.modules import UnstructuredPruningMask
+from nncf.torch.function_hook.wrapper import get_hook_storage
+from tests.torch2.function_hook.pruning.helpers import ConvModel
+
+
+def test_strip():
+    model = ConvModel()
+    example_inputs = ConvModel.get_example_inputs()
+    pruned_model = nncf.prune(
+        model, mode=PruneMode.UNSTRUCTURED_MAGNITUDE_LOCAL, ratio=0.5, examples_inputs=example_inputs
+    )
+    pruned_model.eval()
+
+    hook_storage = get_hook_storage(pruned_model)
+    pruning_module = hook_storage.post_hooks["conv:weight__0"]["0"]
+
+    assert isinstance(pruning_module, UnstructuredPruningMask)
+
+    with torch.no_grad():
+        pruned_weight = pruning_module(pruned_model.conv.weight)
+
+    striped_model = nncf.strip(pruned_model, strip_format=nncf.StripFormat.IN_PLACE, do_copy=False)
+    hook_storage = get_hook_storage(striped_model)
+
+    assert not list(hook_storage.named_hooks())
+    assert torch.equal(striped_model.conv.weight, pruned_weight)
+    assert torch.count_nonzero(striped_model.conv.weight) == 40
