Commit 629e385
Merge strip_in_place functions (#3732)
### Changes

Merge `apply_compression_in_place` and `apply_pruning_in_place` functions.

### Ticket

176329

---------

Co-authored-by: Lyalyushkin Nikolay <[email protected]>
1 parent f832775 commit 629e385
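For orientation, a minimal call-site sketch of the path this commit simplifies. `strip_model` and `StripFormat` appear in the diff below; the model here is a placeholder for an NNCF-wrapped network carrying pruning and/or quantization hooks, and the input shape is invented for the example. Before this change, the `IN_PLACE` branch called `apply_pruning_in_place` and then `apply_compression_in_place`; afterwards a single merged function handles both.

import torch
from nncf.parameters import StripFormat
from nncf.torch.function_hook.strip import strip_model

model = ...  # placeholder: an NNCF-wrapped model with compression hooks
example_input = torch.ones(1, 3, 224, 224)  # hypothetical input shape

# The IN_PLACE format now folds pruning masks and quantizers into the
# weights in one pass and removes the hooks from the model.
stripped = strip_model(model, example_input, strip_format=StripFormat.IN_PLACE)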

File tree

2 files changed: +27 -80 lines changed


src/nncf/torch/function_hook/pruning/strip.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

src/nncf/torch/function_hook/strip.py

Lines changed: 27 additions & 21 deletions
@@ -20,7 +20,8 @@
 from nncf.parameters import StripFormat
 from nncf.torch.function_hook.hook_storage import decode_hook_name
 from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph
-from nncf.torch.function_hook.pruning.strip import apply_pruning_in_place
+from nncf.torch.function_hook.pruning.magnitude.modules import UnstructuredPruningMask
+from nncf.torch.function_hook.pruning.rb.modules import RBPruningMask
 from nncf.torch.function_hook.wrapper import get_hook_storage
 from nncf.torch.model_graph_manager import get_const_data
 from nncf.torch.model_graph_manager import get_const_node
@@ -57,7 +58,6 @@ def strip_model(model: TModel, example_input: Any = None, strip_format: StripFor
     elif strip_format == StripFormat.DQ:
         model = replace_quantizer_to_compressed_weight_with_decompressor(model)
     elif strip_format == StripFormat.IN_PLACE:
-        model = apply_pruning_in_place(model)
         model = apply_compression_in_place(model)
     else:
         msg = f"Unsupported strip format: {strip_format}"
@@ -109,6 +109,7 @@ def replace_quantizer_to_torch_native_module(model: TModel, graph: NNCFGraph) ->
     return model


+@torch.no_grad()
 def replace_quantizer_to_compressed_weight_with_decompressor(model: TModel) -> TModel:
     """
     Performs transformation from fake quantize format (FQ) to dequantization one (DQ):
@@ -136,12 +137,11 @@ def replace_quantizer_to_compressed_weight_with_decompressor(model: TModel) -> T
         module = get_module_by_name(module_name, model)
         weight_param = getattr(module, weight_attr_name)

-        with torch.no_grad():
-            if isinstance(hook_module, AsymmetricQuantizer):
-                decompressor, q_weight = asym_fq_to_decompressor(hook_module, weight_param)
-            else:
-                decompressor, q_weight = sym_fq_to_decompressor(hook_module, weight_param)  # type: ignore[assignment]
-            packed_tensor = decompressor.pack_weight(q_weight)
+        if isinstance(hook_module, AsymmetricQuantizer):
+            decompressor, q_weight = asym_fq_to_decompressor(hook_module, weight_param)
+        else:
+            decompressor, q_weight = sym_fq_to_decompressor(hook_module, weight_param)  # type: ignore[assignment]
+        packed_tensor = decompressor.pack_weight(q_weight)

         weight_param.requires_grad = False
         weight_param.data = packed_tensor
@@ -150,40 +150,46 @@ def replace_quantizer_to_compressed_weight_with_decompressor(model: TModel) -> T
     return model


+@torch.no_grad()
 def apply_compression_in_place(model: TModel) -> TModel:
     """
-    Applies fake quantizers in-place to the weights:
-    (weights + FQ) -> (fake quantized weights)
+    Applies NNCF module in-place to the weights:
+    (weights + NNCF module) -> (in-place compressed weights)

     :param model: Compressed model
-    :param graph: The model graph.
     :return: The modified NNCF network.
     """
     hook_storage = get_hook_storage(model)
-
     hooks_to_delete = []
     for hook_name, hook_module in hook_storage.named_hooks():
-        if not isinstance(hook_module, (SymmetricQuantizer, AsymmetricQuantizer, BaseWeightsDecompressor)):
+        if not isinstance(
+            hook_module,
+            (RBPruningMask, UnstructuredPruningMask, SymmetricQuantizer, AsymmetricQuantizer, BaseWeightsDecompressor),
+        ):
             continue
+
         hook_module.eval()
+        hook_type, op_name, port_id = decode_hook_name(hook_name)
+        if hook_type != "post_hooks" or port_id != 0:
+            msg = f"Unexpected place of Compression Module: {hook_type=}, {op_name=}, {port_id=}"
+            raise nncf.InternalError(msg)

-        _, op_name, _ = decode_hook_name(hook_name)
         module_name, weight_attr_name = split_const_name(op_name)
         module = get_module_by_name(module_name, model)
         weight_param = getattr(module, weight_attr_name)

-        with torch.no_grad():
-            if isinstance(hook_module, (SymmetricQuantizer, AsymmetricQuantizer)):
-                fq_weight = hook_module.quantize(weight_param)
-            else:
-                fq_weight = hook_module(weight_param)
+        if not isinstance(weight_param, torch.nn.Parameter):
+            msg = f"Expected torch.nn.Parameter under {op_name}, got {type(weight_param)}."
+            raise nncf.InternalError(msg)

         weight_param.requires_grad = False
-        weight_param.data = fq_weight
+        if isinstance(hook_module, (SymmetricQuantizer, AsymmetricQuantizer)):
+            weight_param.data = hook_module.quantize(weight_param)
+        else:
+            weight_param.data = hook_module(weight_param)

         hooks_to_delete.append(hook_name)

     for hook_name in hooks_to_delete:
         hook_storage.delete_hook(hook_name)
-
     return model
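To illustrate the pattern the merged `apply_compression_in_place` relies on, here is a self-contained toy sketch: a compression hook is evaluated once on a weight with gradient tracking disabled, the result is written back into the parameter's `.data`, and the hook is then discarded. `ToyPruningMask` is a stand-in invented for this sketch, not an NNCF class.

import torch
from torch import nn

class ToyPruningMask(nn.Module):
    """Stand-in for a pruning hook: zeroes weights selected by a binary mask."""

    def __init__(self, mask: torch.Tensor):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight * self.mask

linear = nn.Linear(4, 4)
hooks = {"weight": ToyPruningMask((torch.rand(4, 4) > 0.5).float())}

with torch.no_grad():  # mirrors the @torch.no_grad() decorator added above
    for attr_name, hook in hooks.items():
        hook.eval()
        param = getattr(linear, attr_name)
        param.requires_grad = False
        param.data = hook(param)  # bake the hook's effect into the weight

hooks.clear()  # once baked in, the hooks are deleted, as in the diff
print(linear.weight)  # masked weights; no hook runs at inference time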
