diff --git a/src/nncf/onnx/graph/model_transformer.py b/src/nncf/onnx/graph/model_transformer.py
index 889cf843a59..540ea1eb9c5 100644
--- a/src/nncf/onnx/graph/model_transformer.py
+++ b/src/nncf/onnx/graph/model_transformer.py
@@ -29,6 +29,7 @@
 from nncf.onnx.graph.onnx_helper import get_tensor
 from nncf.onnx.graph.transformations.commands import ONNXInitializerUpdateCommand
 from nncf.onnx.graph.transformations.commands import ONNXModelExtractionCommand
+from nncf.onnx.graph.transformations.commands import ONNXMultiplyInsertionCommand
 from nncf.onnx.graph.transformations.commands import ONNXOutputInsertionCommand
 from nncf.onnx.graph.transformations.commands import ONNXQDQNodeRemovingCommand
 from nncf.onnx.graph.transformations.commands import ONNXQuantizerInsertionCommand
@@ -91,6 +92,7 @@ def transform(self, transformation_layout: TransformationLayout) -> onnx.ModelPr
         initializer_update_transformations = []
         qdq_node_removing_transformations = []
         model_extraction_transformation = None
+        multiply_insert_transformations = []
         transformations = transformation_layout.transformations
         # No transformation applied
         if not transformations:
@@ -106,8 +108,15 @@
                 qdq_node_removing_transformations.append(transformation)
             elif isinstance(transformation, ONNXInitializerUpdateCommand):
                 initializer_update_transformations.append(transformation)
+            elif isinstance(transformation, ONNXMultiplyInsertionCommand):
+                multiply_insert_transformations.append(transformation)
         # Inplace transformations, using deepcopy of model
-        if quantizer_insert_transformations or initializer_update_transformations or qdq_node_removing_transformations:
+        if (
+            quantizer_insert_transformations
+            or initializer_update_transformations
+            or qdq_node_removing_transformations
+            or multiply_insert_transformations
+        ):
             model = deepcopy(self._model)
             if quantizer_insert_transformations:
                 model = self._apply_quantizer_insertion_transformations(model, quantizer_insert_transformations)
@@ -115,6 +124,8 @@
             model = self._apply_qdq_node_removing_transformations(model, qdq_node_removing_transformations)
             if initializer_update_transformations:
                 model = self._apply_initializer_update_transformations(model, initializer_update_transformations)
+            if multiply_insert_transformations:
+                model = self._apply_multiply_insertion_transformations(model, multiply_insert_transformations)
         # Transformations that create new model
         if output_insert_transformations:
             model = self._apply_output_insertion_transformations(output_insert_transformations)
@@ -459,6 +470,50 @@
         return model
 
+    @staticmethod
+    def _apply_multiply_insertion_transformations(
+        model: onnx.ModelProto, transformations: list[ONNXMultiplyInsertionCommand]
+    ) -> onnx.ModelProto:
+        """
+        Inserts a Multiply node with the provided scale value after the corresponding layer.
+
+        :param model: Model to apply transformations to.
+        :param transformations: List of the Multiply insertion transformations.
+        :returns: Transformed model with Multiply nodes.
+ """ + node_name_to_node = get_name_to_node_map(model) + + for transformation in transformations: + target_node_name = transformation.target_point.target_node_name + target_output_port = transformation.target_point.port_id + target_node = node_name_to_node[target_node_name] + output_tensor_name = target_node.output[target_output_port] + + # Create a new initializer for the scale constant + scale_tensor_name = f"{transformation.multiply_node_name}_scale" + scale_tensor = onnx.numpy_helper.from_array(transformation.scale_value, name=scale_tensor_name) + model.graph.initializer.append(scale_tensor) + + # Create a new Multiply node + mul_output_name = f"{transformation.multiply_node_name}_output" + mul_node = onnx.helper.make_node( + "Mul", + inputs=[output_tensor_name, scale_tensor_name], + outputs=[mul_output_name], + name=transformation.multiply_node_name, + ) + target_index = get_node_index(model, target_node_name) + model.graph.node.insert(target_index + 1, mul_node) + + for name in transformation.destination_node_names: + node = node_name_to_node[name] + for i, input_name in enumerate(node.input): + if input_name == output_tensor_name: + node.input[i] = mul_output_name + + return model + def set_initializer(initializer_name: str, model: onnx.ModelProto, new_value: np.ndarray) -> None: """ diff --git a/src/nncf/onnx/graph/transformations/command_creation.py b/src/nncf/onnx/graph/transformations/command_creation.py index 933a6989c37..b83077c2cb0 100644 --- a/src/nncf/onnx/graph/transformations/command_creation.py +++ b/src/nncf/onnx/graph/transformations/command_creation.py @@ -16,6 +16,7 @@ from nncf.common.graph.transformations.command_creation import CommandCreator from nncf.common.graph.transformations.commands import TargetType from nncf.onnx.graph.transformations.commands import ONNXInitializerUpdateCommand +from nncf.onnx.graph.transformations.commands import ONNXMultiplyInsertionCommand from nncf.onnx.graph.transformations.commands import ONNXQDQNodeRemovingCommand from nncf.onnx.graph.transformations.commands import ONNXTargetPoint @@ -59,3 +60,15 @@ def create_command_to_update_weight( @staticmethod def create_command_to_insert_bias(node_without_bias, bias_value): raise NotImplementedError + + @staticmethod + def multiply_insertion_command( + source_node: NNCFNode, + destination_nodes: list[NNCFNode], + source_out_port: int, + scale_value: np.ndarray, + multiply_node_name: str, + ) -> ONNXMultiplyInsertionCommand: + target_point = ONNXTargetPoint(TargetType.POST_LAYER_OPERATION, source_node.node_name, source_out_port) + destination_node_names = [d.node_name for d in destination_nodes] + return ONNXMultiplyInsertionCommand(target_point, scale_value, destination_node_names, multiply_node_name) diff --git a/src/nncf/onnx/graph/transformations/commands.py b/src/nncf/onnx/graph/transformations/commands.py index 03dfa792704..b3bd08a48cd 100644 --- a/src/nncf/onnx/graph/transformations/commands.py +++ b/src/nncf/onnx/graph/transformations/commands.py @@ -113,3 +113,27 @@ def __init__(self, target_point: ONNXTargetPoint): :param target_point: The TargetPoint instance for the layer that contains information for removing. """ super().__init__(TransformationType.REMOVE, target_point) + + +class ONNXMultiplyInsertionCommand(ONNXInsertionCommand): + """ + Inserts Multiply nodes before the corresponding nodes. 
+ """ + + def __init__( + self, + target_point: ONNXTargetPoint, + scale_value: np.ndarray, + destination_node_names: list[str], + multiply_node_name: str, + ): + """ + :param target_point: The TargetPoint instance for the insertion that contains layer's information. + :param scale_value: Scale value for Multiply layer. + :param destination_node_names: New layer consumers. + :param multiply_node_name: New layer name. + """ + super().__init__(target_point, None) + self.scale_value = scale_value + self.destination_node_names = destination_node_names + self.multiply_node_name = multiply_node_name diff --git a/src/nncf/quantization/algorithms/weight_compression/awq.py b/src/nncf/quantization/algorithms/weight_compression/awq.py index 38088179990..210b3abb5da 100644 --- a/src/nncf/quantization/algorithms/weight_compression/awq.py +++ b/src/nncf/quantization/algorithms/weight_compression/awq.py @@ -114,6 +114,10 @@ def _set_backend_entity( from nncf.quantization.algorithms.weight_compression.torch_fx_backend import FXAWQAlgoAlgoBackend self._backend_entity = FXAWQAlgoAlgoBackend() + elif model_backend == BackendType.ONNX: + from nncf.quantization.algorithms.weight_compression.onnx_backend import ONNXAWQAlgoAlgoBackend + + self._backend_entity = ONNXAWQAlgoAlgoBackend(model) else: msg = f"Cannot return backend-specific AWQ entity because {model_backend.value} is not supported!" raise nncf.UnsupportedBackendError(msg) diff --git a/src/nncf/quantization/algorithms/weight_compression/onnx_backend.py b/src/nncf/quantization/algorithms/weight_compression/onnx_backend.py index 2353641b258..a0d4411401c 100644 --- a/src/nncf/quantization/algorithms/weight_compression/onnx_backend.py +++ b/src/nncf/quantization/algorithms/weight_compression/onnx_backend.py @@ -30,6 +30,8 @@ from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic +from nncf.onnx.graph.metatypes import onnx_metatypes +from nncf.onnx.graph.metatypes.groups import ATOMIC_ACTIVATIONS_OPERATIONS from nncf.onnx.graph.metatypes.groups import CONVOLUTION_METATYPES from nncf.onnx.graph.metatypes.groups import MATMUL_METATYPES from nncf.onnx.graph.model_transformer import remove_initializer @@ -43,11 +45,14 @@ from nncf.onnx.graph.onnx_helper import get_tensor_value from nncf.onnx.graph.onnx_helper import pack_4_bits from nncf.onnx.graph.onnx_helper import pack_int4_to_uint8 +from nncf.onnx.graph.transformations.command_creation import ONNXCommandCreator from nncf.onnx.graph.transformations.commands import ONNXTargetPoint from nncf.onnx.quantization.ignored_patterns import create_rope from nncf.parameters import CompressionFormat from nncf.parameters import CompressWeightsMode from nncf.quantization.advanced_parameters import AdvancedCompressionParameters +from nncf.quantization.algorithms.weight_compression.awq_patterns import get_awq_patterns +from nncf.quantization.algorithms.weight_compression.backend import AWQAlgoBackend from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters from nncf.quantization.algorithms.weight_compression.lora_correction import LoraCorrectionAlgorithm @@ -181,7 +186,7 @@ def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNC def set_weight( self, node_with_weight: NNCFNode, 
     ):
-        node = self.name_to_node_map[node_with_weight.target_node_name]
+        node = self.name_to_node_map[node_with_weight.node_name]
         initializer_name = node.input[weight_port_id]
         set_initializer(initializer_name, model, weight.data)
@@ -464,3 +469,19 @@ def _replace_matmul_with_matmulnbits(
     @staticmethod
     def get_ignored_patterns() -> GraphPattern:
         return create_rope()
+
+
+class ONNXAWQAlgoAlgoBackend(AWQAlgoBackend, ONNXWeightCompressionAlgoBackend):
+    @staticmethod
+    def get_awq_patterns() -> dict[str, Callable]:
+        return get_awq_patterns(
+            onnx_metatypes.ONNXMatMulMetatype, onnx_metatypes.ONNXMulLayerMetatype, ATOMIC_ACTIVATIONS_OPERATIONS
+        )
+
+    @staticmethod
+    def scale_insertion_command(
+        source_node: NNCFNode, next_nodes: list[NNCFNode], source_node_output_port: int, scale: np.ndarray
+    ):
+        return ONNXCommandCreator.multiply_insertion_command(
+            source_node, next_nodes, source_node_output_port, scale, f"{source_node.node_name}/awq_mul"
+        )
diff --git a/src/nncf/quantization/quantize_model.py b/src/nncf/quantization/quantize_model.py
index 340f5983f2b..98c2feaa7cb 100644
--- a/src/nncf/quantization/quantize_model.py
+++ b/src/nncf/quantization/quantize_model.py
@@ -634,7 +634,6 @@ def compress_weights(
             raise nncf.ParameterNotSupportedError(msg)
 
         options = {
-            "awq": awq,
             "scale_estimation": scale_estimation,
             "gptq": gptq,
             "lora_correction": lora_correction,
diff --git a/tests/post_training/data/wc_reference_data.yaml b/tests/post_training/data/wc_reference_data.yaml
index 6fe2d270d3f..79ed077d543 100644
--- a/tests/post_training/data/wc_reference_data.yaml
+++ b/tests/post_training/data/wc_reference_data.yaml
@@ -118,3 +118,7 @@ tinyllama_data_free_awq_backend_TORCH:
   metric_value: 0.85466
   num_int4: 94
   num_int8: 124
+tinyllama_data_free_awq_backend_ONNX:
+  metric_value: 0.82562
+  num_int4: 264
+  num_int8: 84
diff --git a/tests/post_training/data/wc_test_durations.json b/tests/post_training/data/wc_test_durations.json
index 15b837cbe4c..f02c4cee7dd 100644
--- a/tests/post_training/data/wc_test_durations.json
+++ b/tests/post_training/data/wc_test_durations.json
@@ -15,5 +15,6 @@
     "tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_awq_backend_OV]": 164,
     "tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_awq_backend_TORCH]": 210,
     "tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_backend_ONNX]": 182,
-    "tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_opset19_backend_ONNX]": 512
+    "tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_opset19_backend_ONNX]": 512,
+    "tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_awq_backend_ONNX]": 154
 }
diff --git a/tests/post_training/model_scope.py b/tests/post_training/model_scope.py
index 1fa58c8470b..89c84398b9a 100644
--- a/tests/post_training/model_scope.py
+++ b/tests/post_training/model_scope.py
@@ -579,7 +579,7 @@
             ),
         },
         # TODO: (andreyanufr) add torch.fx backend
-        "backends": [BackendType.OV, BackendType.TORCH],
+        "backends": [BackendType.OV, BackendType.TORCH, BackendType.ONNX],
     },
 ]
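
A minimal usage sketch (hypothetical, not part of this diff): with the backend wiring above, AWQ can be requested for an ONNX model through the public `nncf.compress_weights` API, which is the path the `tinyllama_data_free_awq_backend_ONNX` conformance entry exercises. The model path and compression settings below are illustrative assumptions, not values taken from this patch.

```python
import onnx

import nncf
from nncf import CompressWeightsMode

# Illustrative model path; any ONNX LLM with MatMul weights applies.
model = onnx.load("tinyllama.onnx")

# awq=True is now routed to ONNXAWQAlgoAlgoBackend instead of being rejected
# as an unsupported option; with no dataset given, the data-free AWQ path runs.
compressed_model = nncf.compress_weights(
    model,
    mode=CompressWeightsMode.INT4_SYM,
    ratio=0.8,      # illustrative int4/int8 split
    group_size=64,  # illustrative group size
    awq=True,
)
onnx.save(compressed_model, "tinyllama_int4_awq.onnx")
```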