 from nncf.parameters import StripFormat
 from nncf.torch.function_hook.wrapper import get_hook_storage
 from nncf.torch.quantization.layers import BaseQuantizer
+from nncf.torch.quantization.layers import BaseWeightsDecompressor
 from tests.torch.helpers import LinearModel
 from tests.torch2.function_hook.quantization.strip.test_strip_dequantize import check_compression_modules
 
@@ -45,6 +46,16 @@ def extra_arguments(self) -> dict[str, Any]:
         args["group_size"] = -1
         return args
 
+    @property
+    def compression_class(self) -> Any:
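+        # DQ strip leaves integer weight-decompression modules in the model; FQ/FQ_LORA keep fake-quantizers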
+        return BaseWeightsDecompressor if self.compression_format == CompressionFormat.DQ else BaseQuantizer
+
+    @property
+    def compression_dtype(self) -> Any:
+        if self.compression_format == CompressionFormat.DQ:
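+            # symmetric INT8 weights are stored as signed int8, asymmetric as unsigned uint8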
+            return torch.int8 if self.mode == CompressWeightsMode.INT8_SYM else torch.uint8
+        return self.torch_dtype
+
 
 @pytest.mark.parametrize(
     "param",
@@ -57,7 +68,7 @@ def extra_arguments(self) -> dict[str, Any]:
                 CompressWeightsMode.INT8_ASYM,
                 CompressWeightsMode.INT8_SYM,
             ],
-            [CompressionFormat.FQ_LORA, CompressionFormat.FQ],
+            [CompressionFormat.FQ_LORA, CompressionFormat.FQ, CompressionFormat.DQ],
             [torch.float32, torch.float16, torch.bfloat16],
         )
     ],
@@ -77,8 +88,8 @@ def test_nncf_in_place_strip(param: ParamInPlaceStrip):
         **param.extra_arguments,
     )
 
-    check_compression_modules(compressed_model, expected_class=BaseQuantizer)
-    assert compressed_model.linear.weight.dtype == param.torch_dtype
+    check_compression_modules(compressed_model, expected_class=param.compression_class)
+    assert compressed_model.linear.weight.dtype == param.compression_dtype
 
     with torch.no_grad():
         compressed_output = compressed_model(example_input)
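
For readers skimming the diff, the dtype expectation the new compression_dtype property encodes can be summarized standalone. This is a minimal sketch, not part of the change: expected_weight_dtype is a hypothetical helper, and the imports assume CompressionFormat and CompressWeightsMode resolve from nncf.parameters as in the test file above.

import torch

from nncf.parameters import CompressionFormat
from nncf.parameters import CompressWeightsMode


def expected_weight_dtype(fmt, mode, torch_dtype):
    # Mirrors ParamInPlaceStrip.compression_dtype from the diff:
    # DQ strip stores the packed integer weight (int8 for symmetric INT8,
    # uint8 otherwise), while FQ/FQ_LORA keep the floating-point dtype.
    if fmt == CompressionFormat.DQ:
        return torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8
    return torch_dtype


assert expected_weight_dtype(CompressionFormat.DQ, CompressWeightsMode.INT8_SYM, torch.float32) == torch.int8
assert expected_weight_dtype(CompressionFormat.DQ, CompressWeightsMode.INT8_ASYM, torch.float32) == torch.uint8
assert expected_weight_dtype(CompressionFormat.FQ, CompressWeightsMode.INT8_SYM, torch.float16) == torch.float16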