2929from nncf .onnx .graph .onnx_helper import get_tensor
3030from nncf .onnx .graph .transformations .commands import ONNXInitializerUpdateCommand
3131from nncf .onnx .graph .transformations .commands import ONNXModelExtractionCommand
32+ from nncf .onnx .graph .transformations .commands import ONNXMultiplyInsertionCommand
3233from nncf .onnx .graph .transformations .commands import ONNXOutputInsertionCommand
3334from nncf .onnx .graph .transformations .commands import ONNXQDQNodeRemovingCommand
3435from nncf .onnx .graph .transformations .commands import ONNXQuantizerInsertionCommand
@@ -91,6 +92,7 @@ def transform(self, transformation_layout: TransformationLayout) -> onnx.ModelPr
9192 initializer_update_transformations = []
9293 qdq_node_removing_transformations = []
9394 model_extraction_transformation = None
95+ multiply_insert_transformations = []
9496 transformations = transformation_layout .transformations
9597 # No transformation applied
9698 if not transformations :
@@ -106,15 +108,24 @@ def transform(self, transformation_layout: TransformationLayout) -> onnx.ModelPr
106108 qdq_node_removing_transformations .append (transformation )
107109 elif isinstance (transformation , ONNXInitializerUpdateCommand ):
108110 initializer_update_transformations .append (transformation )
111+ elif isinstance (transformation , ONNXMultiplyInsertionCommand ):
112+ multiply_insert_transformations .append (transformation )
109113 # Inplace transformations, using deepcopy of model
110- if quantizer_insert_transformations or initializer_update_transformations or qdq_node_removing_transformations :
114+ if (
115+ quantizer_insert_transformations
116+ or initializer_update_transformations
117+ or qdq_node_removing_transformations
118+ or multiply_insert_transformations
119+ ):
111120 model = deepcopy (self ._model )
112121 if quantizer_insert_transformations :
113122 model = self ._apply_quantizer_insertion_transformations (model , quantizer_insert_transformations )
114123 if qdq_node_removing_transformations :
115124 model = self ._apply_qdq_node_removing_transformations (model , qdq_node_removing_transformations )
116125 if initializer_update_transformations :
117126 model = self ._apply_initializer_update_transformations (model , initializer_update_transformations )
127+ if multiply_insert_transformations :
128+ model = self ._apply_multiply_insertion_transformations (model , multiply_insert_transformations )
118129 # Transformations that create new model
119130 if output_insert_transformations :
120131 model = self ._apply_output_insertion_transformations (output_insert_transformations )
@@ -459,6 +470,48 @@ def _apply_qdq_node_removing_transformations(
459470
460471 return model
461472
@staticmethod
def _apply_multiply_insertion_transformations(
    model: onnx.ModelProto, transformations: list[ONNXMultiplyInsertionCommand]
) -> onnx.ModelProto:
    """
    Inserts a Mul node, fed by a constant scale, after the targeted output of each requested node.

    For every command a new initializer holding the scale value is appended to the graph,
    a Mul node consuming the target output and that initializer is placed right after the
    target node, and the listed destination nodes are rewired to read the Mul output.

    :param model: Model to transform; its graph is modified in place.
    :param transformations: List of the multiply insertion transformations.
    :returns: The same model object with the Mul nodes inserted.
    """
    nodes_by_name = get_name_to_node_map(model)

    for command in transformations:
        producer_name = command.target_point.target_node_name
        producer_port = command.target_point.port_id
        producer = nodes_by_name[producer_name]
        produced_tensor = producer.output[producer_port]

        mul_name = command.multiply_node_name

        # Register the scale constant as a graph initializer.
        scale_name = f"{mul_name}_scale"
        model.graph.initializer.append(onnx.numpy_helper.from_array(command.scale_value, name=scale_name))

        # Build the Mul node that scales the producer's output.
        scaled_tensor = f"{mul_name}_output"
        mul_node = onnx.helper.make_node(
            "Mul",
            inputs=[produced_tensor, scale_name],
            outputs=[scaled_tensor],
            name=mul_name,
        )

        # The producer's position is looked up on every iteration, so it reflects
        # any nodes inserted by previous commands.
        insert_at = get_node_index(model, producer_name) + 1
        model.graph.node.insert(insert_at, mul_node)

        # Redirect only the listed consumers; any other reader of the original
        # tensor keeps its input unchanged.
        for consumer_name in command.destination_node_names:
            consumer_inputs = nodes_by_name[consumer_name].input
            for slot, tensor_name in enumerate(consumer_inputs):
                if tensor_name != produced_tensor:
                    continue
                consumer_inputs[slot] = scaled_tensor

    return model
514+
462515
463516def set_initializer (initializer_name : str , model : onnx .ModelProto , new_value : np .ndarray ) -> None :
464517 """
0 commit comments