Skip to content

Commit 61fc34a

Browse files
authored
[Torch] In-place strip for WeightDecompressor classes (#3709)
### Changes Extended in-place strip for CompressionFormat.DQ ### Reason for changes Faster evaluation of compressed models in Torch ### Related tickets n/a ### Tests test_nncf_in_place_strip
1 parent 5d88dd6 commit 61fc34a

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

src/nncf/torch/function_hook/strip.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from nncf.torch.model_graph_manager import split_const_name
2828
from nncf.torch.quantization.layers import AsymmetricQuantizer
2929
from nncf.torch.quantization.layers import BaseQuantizer
30+
from nncf.torch.quantization.layers import BaseWeightsDecompressor
3031
from nncf.torch.quantization.layers import SymmetricQuantizer
3132
from nncf.torch.quantization.strip import asym_fq_to_decompressor
3233
from nncf.torch.quantization.strip import convert_to_torch_fakequantizer
@@ -167,7 +168,7 @@ def apply_compression_in_place(model: TModel, graph: NNCFGraph) -> TModel:
167168

168169
hooks_to_delete = []
169170
for name, hook in hook_storage.named_hooks():
170-
if not isinstance(hook, (SymmetricQuantizer, AsymmetricQuantizer)):
171+
if not isinstance(hook, (SymmetricQuantizer, AsymmetricQuantizer, BaseWeightsDecompressor)):
171172
continue
172173
_, op_name, _ = decode_hook_name(name)
173174
weight_node = graph.get_node_by_name(op_name)
@@ -181,7 +182,7 @@ def apply_compression_in_place(model: TModel, graph: NNCFGraph) -> TModel:
181182
raise nncf.InternalError(msg)
182183

183184
weight = get_const_data(weight_node, model)
184-
fq_weight = hook.quantize(weight)
185+
fq_weight = hook(weight) if isinstance(hook, BaseWeightsDecompressor) else hook.quantize(weight)
185186

186187
module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name)
187188
module = get_module_by_name(module_name, model)

tests/torch2/function_hook/quantization/strip/test_strip_in_place.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from nncf.parameters import StripFormat
2424
from nncf.torch.function_hook.wrapper import get_hook_storage
2525
from nncf.torch.quantization.layers import BaseQuantizer
26+
from nncf.torch.quantization.layers import BaseWeightsDecompressor
2627
from tests.torch.helpers import LinearModel
2728
from tests.torch2.function_hook.quantization.strip.test_strip_dequantize import check_compression_modules
2829

@@ -45,6 +46,16 @@ def extra_arguments(self) -> dict[str, Any]:
4546
args["group_size"] = -1
4647
return args
4748

49+
@property
50+
def compression_class(self) -> Any:
51+
return BaseWeightsDecompressor if self.compression_format == CompressionFormat.DQ else BaseQuantizer
52+
53+
@property
54+
def compression_dtype(self) -> Any:
55+
if self.compression_format == CompressionFormat.DQ:
56+
return torch.int8 if self.mode == CompressWeightsMode.INT8_SYM else torch.uint8
57+
return self.torch_dtype
58+
4859

4960
@pytest.mark.parametrize(
5061
"param",
@@ -57,7 +68,7 @@ def extra_arguments(self) -> dict[str, Any]:
5768
CompressWeightsMode.INT8_ASYM,
5869
CompressWeightsMode.INT8_SYM,
5970
],
60-
[CompressionFormat.FQ_LORA, CompressionFormat.FQ],
71+
[CompressionFormat.FQ_LORA, CompressionFormat.FQ, CompressionFormat.DQ],
6172
[torch.float32, torch.float16, torch.bfloat16],
6273
)
6374
],
@@ -77,8 +88,8 @@ def test_nncf_in_place_strip(param: ParamInPlaceStrip):
7788
**param.extra_arguments,
7889
)
7990

80-
check_compression_modules(compressed_model, expected_class=BaseQuantizer)
81-
assert compressed_model.linear.weight.dtype == param.torch_dtype
91+
check_compression_modules(compressed_model, expected_class=param.compression_class)
92+
assert compressed_model.linear.weight.dtype == param.compression_dtype
8293

8394
with torch.no_grad():
8495
compressed_output = compressed_model(example_input)

0 commit comments

Comments
 (0)