2020from nncf .parameters import StripFormat
2121from nncf .torch .function_hook .hook_storage import decode_hook_name
2222from nncf .torch .function_hook .nncf_graph .nncf_graph_builder import build_nncf_graph
23- from nncf .torch .function_hook .pruning .strip import apply_pruning_in_place
23+ from nncf .torch .function_hook .pruning .magnitude .modules import UnstructuredPruningMask
24+ from nncf .torch .function_hook .pruning .rb .modules import RBPruningMask
2425from nncf .torch .function_hook .wrapper import get_hook_storage
2526from nncf .torch .model_graph_manager import get_const_data
2627from nncf .torch .model_graph_manager import get_const_node
@@ -57,7 +58,6 @@ def strip_model(model: TModel, example_input: Any = None, strip_format: StripFor
5758 elif strip_format == StripFormat .DQ :
5859 model = replace_quantizer_to_compressed_weight_with_decompressor (model )
5960 elif strip_format == StripFormat .IN_PLACE :
60- model = apply_pruning_in_place (model )
6161 model = apply_compression_in_place (model )
6262 else :
6363 msg = f"Unsupported strip format: { strip_format } "
@@ -109,6 +109,7 @@ def replace_quantizer_to_torch_native_module(model: TModel, graph: NNCFGraph) ->
109109 return model
110110
111111
112+ @torch .no_grad ()
112113def replace_quantizer_to_compressed_weight_with_decompressor (model : TModel ) -> TModel :
113114 """
114115 Performs transformation from fake quantize format (FQ) to dequantization one (DQ):
@@ -136,12 +137,11 @@ def replace_quantizer_to_compressed_weight_with_decompressor(model: TModel) -> T
136137 module = get_module_by_name (module_name , model )
137138 weight_param = getattr (module , weight_attr_name )
138139
139- with torch .no_grad ():
140- if isinstance (hook_module , AsymmetricQuantizer ):
141- decompressor , q_weight = asym_fq_to_decompressor (hook_module , weight_param )
142- else :
143- decompressor , q_weight = sym_fq_to_decompressor (hook_module , weight_param ) # type: ignore[assignment]
144- packed_tensor = decompressor .pack_weight (q_weight )
140+ if isinstance (hook_module , AsymmetricQuantizer ):
141+ decompressor , q_weight = asym_fq_to_decompressor (hook_module , weight_param )
142+ else :
143+ decompressor , q_weight = sym_fq_to_decompressor (hook_module , weight_param ) # type: ignore[assignment]
144+ packed_tensor = decompressor .pack_weight (q_weight )
145145
146146 weight_param .requires_grad = False
147147 weight_param .data = packed_tensor
@@ -150,40 +150,46 @@ def replace_quantizer_to_compressed_weight_with_decompressor(model: TModel) -> T
150150 return model
151151
152152
@torch.no_grad()
def apply_compression_in_place(model: TModel) -> TModel:
    """
    Applies NNCF module in-place to the weights:
    (weights + NNCF module) -> (in-place compressed weights)

    Each supported compression hook (pruning mask, fake quantizer, or weight
    decompressor) is baked directly into the weight parameter it targets and
    then removed from the hook storage.

    :param model: Compressed model
    :return: The modified NNCF network.
    """
    supported_hook_types = (
        RBPruningMask,
        UnstructuredPruningMask,
        SymmetricQuantizer,
        AsymmetricQuantizer,
        BaseWeightsDecompressor,
    )
    hook_storage = get_hook_storage(model)
    applied_hook_names = []
    for hook_name, hook_module in hook_storage.named_hooks():
        if not isinstance(hook_module, supported_hook_types):
            continue

        # Switch to inference behavior before baking the transformation in.
        hook_module.eval()

        hook_type, op_name, port_id = decode_hook_name(hook_name)
        # Compression modules are expected only as post-hooks on port 0 of the
        # weight-producing op; anything else indicates a malformed model.
        if hook_type != "post_hooks" or port_id != 0:
            msg = f"Unexpected place of Compression Module: {hook_type=}, {op_name=}, {port_id=}"
            raise nncf.InternalError(msg)

        module_name, weight_attr_name = split_const_name(op_name)
        owning_module = get_module_by_name(module_name, model)
        weight_param = getattr(owning_module, weight_attr_name)

        if not isinstance(weight_param, torch.nn.Parameter):
            msg = f"Expected torch.nn.Parameter under {op_name}, got {type(weight_param)}."
            raise nncf.InternalError(msg)

        weight_param.requires_grad = False
        # Quantizers expose an explicit quantize(); pruning masks and
        # decompressors apply their transformation when called.
        if isinstance(hook_module, (SymmetricQuantizer, AsymmetricQuantizer)):
            weight_param.data = hook_module.quantize(weight_param)
        else:
            weight_param.data = hook_module(weight_param)

        applied_hook_names.append(hook_name)

    # Delete after the walk to avoid mutating the storage while iterating it.
    for hook_name in applied_hook_names:
        hook_storage.delete_hook(hook_name)
    return model
0 commit comments