@@ -20,6 +20,7 @@
 from nncf.parameters import StripFormat
 from nncf.torch.function_hook.hook_storage import decode_hook_name
 from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph
+from nncf.torch.function_hook.pruning.strip import apply_pruning_in_place
 from nncf.torch.function_hook.wrapper import get_hook_storage
 from nncf.torch.model_graph_manager import get_const_data
 from nncf.torch.model_graph_manager import get_const_node
@@ -36,7 +37,7 @@
 TModel = TypeVar("TModel", bound=nn.Module)
 
 
-def strip_quantized_model(model: TModel, example_input: Any, strip_format: StripFormat = StripFormat.NATIVE) -> TModel:
+def strip_model(model: TModel, example_input: Any = None, strip_format: StripFormat = StripFormat.NATIVE) -> TModel:
     """
     Removes auxiliary layers and operations added during the quantization process,
     resulting in a clean quantized model ready for deployment. The functionality of the model object is still preserved
@@ -47,14 +48,17 @@ def strip_quantized_model(model: TModel, example_input: Any, strip_format: Strip
     :param strip_format: Describes the format in which model is saved after strip.
     :return: The modified NNCF network.
     """
-    graph = build_nncf_graph(model, example_input)
-
     if strip_format == StripFormat.NATIVE:
+        if example_input is None:
+            msg = "The example_input parameter is required to strip the model."
+            raise nncf.InternalError(msg)
+        graph = build_nncf_graph(model, example_input)
         model = replace_quantizer_to_torch_native_module(model, graph)
     elif strip_format == StripFormat.DQ:
-        model = replace_quantizer_to_compressed_weight_with_decompressor(model, graph)
+        model = replace_quantizer_to_compressed_weight_with_decompressor(model)
     elif strip_format == StripFormat.IN_PLACE:
-        model = apply_compression_in_place(model, graph)
+        model = apply_pruning_in_place(model)
+        model = apply_compression_in_place(model)
     else:
         msg = f"Unsupported strip format: {strip_format}"
         raise nncf.ParameterNotSupportedError(msg)
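For orientation, a minimal usage sketch of the renamed entry point, assuming `model` is a quantized `nn.Module`; the import path and tensor shape are assumptions for illustration, not part of this change:

```python
import torch
from nncf.parameters import StripFormat
from nncf.torch.function_hook.strip import strip_model  # import path assumed

# Pick one format per model; the calls below are alternatives.
# NATIVE rebuilds the graph, so it is the only mode that needs example_input:
stripped = strip_model(model, example_input=torch.ones(1, 3, 224, 224))
# DQ and IN_PLACE now operate purely on the hook storage, no tracing required:
stripped = strip_model(model, strip_format=StripFormat.DQ)
stripped = strip_model(model, strip_format=StripFormat.IN_PLACE)
```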
@@ -105,57 +109,48 @@ def replace_quantizer_to_torch_native_module(model: TModel, graph: NNCFGraph) ->
     return model
 
 
-def replace_quantizer_to_compressed_weight_with_decompressor(model: TModel, graph: NNCFGraph) -> TModel:
+def replace_quantizer_to_compressed_weight_with_decompressor(model: TModel) -> TModel:
     """
     Performs transformation from fake quantize format (FQ) to dequantization one (DQ):
     (weights + FQ) -> (compressed_weights + DQ)
 
     :param model: Compressed model
-    :param graph: The model graph.
     :return: The modified NNCF network.
     """
     hook_storage = get_hook_storage(model)
 
-    for name, module in hook_storage.named_hooks():
-        if not isinstance(module, (SymmetricQuantizer, AsymmetricQuantizer)):
+    for hook_name, hook_module in hook_storage.named_hooks():
+        if not isinstance(hook_module, (SymmetricQuantizer, AsymmetricQuantizer)):
             continue
         msg = ""
-        if module._qspec.half_range or module._qspec.narrow_range:
+        if hook_module._qspec.half_range or hook_module._qspec.narrow_range:
             msg += "Unexpected parameters of quantizers on strip: half_range and narrow_range should be False.\n"
-        if module.num_bits not in [4, 8]:
-            msg += f"Unsupported number of bits {module.num_bits} for the quantizer {module}.\n"
+        if hook_module.num_bits not in [4, 8]:
+            msg += f"Unsupported number of bits {hook_module.num_bits} for the quantizer {hook_module}.\n"
         if msg:
             raise nncf.ValidationError(msg)
 
-        _, op_name, _ = decode_hook_name(name)
-        weight_node = graph.get_node_by_name(op_name)
-
-        if weight_node is None:
-            msg = "FQ is not assigned to weight. Strip to DQ format is not supported for FQ on activation."
-            raise nncf.UnsupportedModelError(msg)
-
-        if not isinstance(weight_node.layer_attributes, ConstantLayerAttributes):
-            msg = f"Unexpected layer attributes type {type(weight_node.layer_attributes)}"
-            raise nncf.InternalError(msg)
-
-        weight = get_const_data(weight_node, model)
+        _, op_name, _ = decode_hook_name(hook_name)
 
-        convert_fn = asym_fq_to_decompressor if isinstance(module, AsymmetricQuantizer) else sym_fq_to_decompressor
-        decompressor, q_weight = convert_fn(module, weight)  # type: ignore[operator]
-        packed_tensor = decompressor.pack_weight(q_weight)
-
-        module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name)
+        module_name, weight_attr_name = split_const_name(op_name)
         module = get_module_by_name(module_name, model)
         weight_param = getattr(module, weight_attr_name)
 
+        with torch.no_grad():
+            if isinstance(hook_module, AsymmetricQuantizer):
+                decompressor, q_weight = asym_fq_to_decompressor(hook_module, weight_param)
+            else:
+                decompressor, q_weight = sym_fq_to_decompressor(hook_module, weight_param)  # type: ignore[assignment]
+            packed_tensor = decompressor.pack_weight(q_weight)
+
         weight_param.requires_grad = False
         weight_param.data = packed_tensor
 
-        hook_storage.set_submodule(name, decompressor)
+        hook_storage.set_submodule(hook_name, decompressor)
     return model
 
 
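The transformation this function performs can be illustrated with plain tensors. A minimal sketch of the FQ-to-DQ equivalence for per-tensor symmetric int8, using ordinary torch ops rather than the actual NNCF quantizer and decompressor classes:

```python
import torch

# Fake quantization computes clamp(round(w / scale)) * scale on the fly;
# the DQ format stores the rounded integer tensor once and leaves the
# multiplication by scale to a decompressor hook at runtime.
weight = torch.randn(16, 16)
scale = weight.abs().max() / 127  # per-tensor symmetric int8 scale

q_weight = torch.clamp(torch.round(weight / scale), -128, 127).to(torch.int8)
fake_quantized = torch.clamp(torch.round(weight / scale), -128, 127) * scale
dequantized = q_weight.float() * scale  # what the decompressor reproduces

assert torch.allclose(fake_quantized, dequantized)
```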
-def apply_compression_in_place(model: TModel, graph: NNCFGraph) -> TModel:
+def apply_compression_in_place(model: TModel) -> TModel:
     """
     Applies fake quantizers in-place to the weights:
     (weights + FQ) -> (fake quantized weights)
@@ -167,31 +162,26 @@ def apply_compression_in_place(model: TModel, graph: NNCFGraph) -> TModel:
     hook_storage = get_hook_storage(model)
 
     hooks_to_delete = []
-    for name, hook in hook_storage.named_hooks():
-        if not isinstance(hook, (SymmetricQuantizer, AsymmetricQuantizer, BaseWeightsDecompressor)):
+    for hook_name, hook_module in hook_storage.named_hooks():
+        if not isinstance(hook_module, (SymmetricQuantizer, AsymmetricQuantizer, BaseWeightsDecompressor)):
             continue
-        _, op_name, _ = decode_hook_name(name)
-        weight_node = graph.get_node_by_name(op_name)
+        hook_module.eval()
 
-        if weight_node is None:
-            msg = "FQ is not assigned to weight. In-place strip is not supported for FQ on activation."
-            raise nncf.UnsupportedModelError(msg)
-
-        if not isinstance(weight_node.layer_attributes, ConstantLayerAttributes):
-            msg = f"Unexpected layer attributes type {type(weight_node.layer_attributes)}"
-            raise nncf.InternalError(msg)
-
-        weight = get_const_data(weight_node, model)
-        fq_weight = hook(weight) if isinstance(hook, BaseWeightsDecompressor) else hook.quantize(weight)
-
-        module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name)
+        _, op_name, _ = decode_hook_name(hook_name)
+        module_name, weight_attr_name = split_const_name(op_name)
         module = get_module_by_name(module_name, model)
         weight_param = getattr(module, weight_attr_name)
 
+        with torch.no_grad():
+            if isinstance(hook_module, (SymmetricQuantizer, AsymmetricQuantizer)):
+                fq_weight = hook_module.quantize(weight_param)
+            else:
+                fq_weight = hook_module(weight_param)
+
         weight_param.requires_grad = False
         weight_param.data = fq_weight
 
-        hooks_to_delete.append(name)
+        hooks_to_delete.append(hook_name)
 
     for hook_name in hooks_to_delete:
         hook_storage.delete_hook(hook_name)
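To show what the in-place strip leaves behind, here is a standalone sketch on a plain module, using torch's built-in fake-quantize op in place of the NNCF quantizer hook; the scale and zero point are arbitrary placeholders:

```python
import torch
from torch import nn

linear = nn.Linear(8, 8)
with torch.no_grad():
    # Evaluate the quantizer once and bake the result into the parameter.
    fq_weight = torch.fake_quantize_per_tensor_affine(linear.weight, 0.01, 0, -128, 127)
    linear.weight.requires_grad = False
    linear.weight.data = fq_weight
# The parameter now permanently holds its fake-quantized values, so no
# quantizer hook has to run at inference time.
```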