diff --git a/src/nncf/onnx/graph/metatypes/groups.py b/src/nncf/onnx/graph/metatypes/groups.py index d4291b18175..758b61eee08 100644 --- a/src/nncf/onnx/graph/metatypes/groups.py +++ b/src/nncf/onnx/graph/metatypes/groups.py @@ -139,7 +139,7 @@ OPERATIONS_WITH_BIAS_REDUCED = [ onnx_metatypes.ONNXConvolutionMetatype, onnx_metatypes.ONNXGemmMetatype, - # TODO: Need to add MatMul with the separate bias support (CVS-135433) + onnx_metatypes.ONNXMatMulMetatype, ] OPERATIONS_WITH_BIAS = [ diff --git a/src/nncf/onnx/graph/metatypes/onnx_metatypes.py b/src/nncf/onnx/graph/metatypes/onnx_metatypes.py index 1cf4024005d..98dc927f629 100644 --- a/src/nncf/onnx/graph/metatypes/onnx_metatypes.py +++ b/src/nncf/onnx/graph/metatypes/onnx_metatypes.py @@ -143,7 +143,6 @@ class ONNXMatMulMetatype(ONNXOpWithWeightsMetatype): op_names = ["MatMul"] hw_config_names = [HWConfigOpName.MATMUL] weight_channel_axis = -1 # For port_id=1 - bias_port_id = 2 possible_weight_ports = [0, 1] output_channel_axis = -1 diff --git a/src/nncf/onnx/graph/nncf_graph_builder.py b/src/nncf/onnx/graph/nncf_graph_builder.py index 50271475bef..9aaae598983 100644 --- a/src/nncf/onnx/graph/nncf_graph_builder.py +++ b/src/nncf/onnx/graph/nncf_graph_builder.py @@ -26,6 +26,7 @@ from nncf.onnx.graph.metatypes.groups import OPERATIONS_WITH_BIAS from nncf.onnx.graph.metatypes.groups import POSSIBLE_WEIGHT_LAYER_METATYPES from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXGemmMetatype +from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXMatMulMetatype from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpMetatype from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpWithWeightsMetatype from nncf.onnx.graph.metatypes.onnx_metatypes import get_metatype @@ -39,6 +40,7 @@ from nncf.onnx.graph.onnx_helper import get_output_port_id_for_node_before_output from nncf.onnx.graph.onnx_helper import get_parents_node_mapping from nncf.onnx.graph.onnx_helper import is_node_has_shared_weight +from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference class ONNXLayerAttributes(BaseLayerAttributes): @@ -186,6 +188,7 @@ def _get_bias_attr( node: onnx.NodeProto, model: onnx.ModelProto, parents_node_mapping: dict[str, onnx.NodeProto], + children_node_mapping: dict[str, onnx.NodeProto], ) -> dict[str, str]: """ Returns bias tensor attributes. @@ -193,10 +196,47 @@ def _get_bias_attr( :param node: ONNX node. :param model: ONNX model. :param parents_node_mapping: Mapping from edge name to node which outputs this edge. + :param children_node_mapping: mapping from edge name to nodes which consume this edge as an input. :return: Bias tensor attributes. """ - bias_attrs = {} metatype = get_metatype(model, node) + + if metatype == ONNXMatMulMetatype: + weight_port_ids = _get_weight_port_ids(node, model, parents_node_mapping) + + if not weight_port_ids: + # `node` is a MatMul without weights, so return empty attributes + return {} + + # Retrieve all nodes that consume the output of the MatMul operation. + # The MatMul operation has only one output. + y = node.output[0] + consumers = children_node_mapping[y] + + if len(consumers) != 1 or consumers[0].op_type != "Add": + return {} + + # Here, we are certain that after a `MatMul` operation, there is only + # the `Add` operation. + add_node = consumers[0] + + # Find the input of `add_node` that is not equal to `y`. + tensor_name = None + port_id = None + for i, name in enumerate(add_node.input): + if name != y: + tensor_name = name + port_id = i + break + + # Ensure that `tensor_name` is the output of a `Constant` node or an initializer. + initializer = {x.name: x for x in model.graph.initializer} + if tensor_name in initializer or parents_node_mapping[tensor_name].op_type == "Constant": + return {"node": add_node.name, "name": tensor_name, "port_id": port_id} + else: + return {} + + bias_attrs = {} if _is_node_with_bias(node, model): bias_tensor_port_id = get_bias_tensor_port_id(metatype) bias_edge_name = get_tensor_edge_name(model, node, bias_tensor_port_id, parents_node_mapping) @@ -347,7 +387,9 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph: :return: NNCFGraph. """ onnx_model = GraphConverter._replace_empty_node_name(onnx_model) - onnx_model = onnx.shape_inference.infer_shapes(onnx_model) + # onnx_model = onnx.shape_inference.infer_shapes(onnx_model) + onnx_model = SymbolicShapeInference.infer_shapes(onnx_model) + edge_info_mapping = get_edge_info_mapping(onnx_model) children_node_mapping = get_children_node_mapping(onnx_model) parents_node_mapping = get_parents_node_mapping(onnx_model) @@ -358,7 +400,8 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph: is_shared = None weight_attrs = {} node_attrs = _get_node_attrs(node, onnx_model) - bias_attrs = _get_bias_attr(node, onnx_model, parents_node_mapping) + bias_attrs = _get_bias_attr(node, onnx_model, parents_node_mapping, children_node_mapping) + if weight_port_ids: # If node has weight weight_edge_names = [] for weight_port_id in weight_port_ids: diff --git a/src/nncf/onnx/graph/node_utils.py b/src/nncf/onnx/graph/node_utils.py index 28caaf29f77..8720ec0b4b5 100644 --- a/src/nncf/onnx/graph/node_utils.py +++ b/src/nncf/onnx/graph/node_utils.py @@ -45,6 +45,7 @@ def get_bias_value(node_with_bias: NNCFNode, model: onnx.ModelProto) -> np.ndarr :return: The bias value that is applied to the output tensor of the node's operation. """ assert node_with_bias.layer_attributes.has_bias() + # TODO(andrey-churkin): Support Add + Constant case bias_name = node_with_bias.layer_attributes.bias_attrs["name"] return get_tensor_value(model, bias_name) diff --git a/src/nncf/onnx/graph/transformations/command_creation.py b/src/nncf/onnx/graph/transformations/command_creation.py index b83077c2cb0..c22c3e96820 100644 --- a/src/nncf/onnx/graph/transformations/command_creation.py +++ b/src/nncf/onnx/graph/transformations/command_creation.py @@ -29,8 +29,13 @@ def create_bias_correction_command(node: NNCFNode, bias_value: np.ndarray) -> ON :param bias_value: The new bias value that will be set. :return: The `ONNXInitializerUpdateCommand` command to update bias. """ - bias_port_id = node.metatype.bias_port_id - target_point = ONNXTargetPoint(TargetType.LAYER, node.node_name, bias_port_id) + node_name = node.layer_attributes.bias_attrs.get("node") + if node_name: + port_id = node.layer_attributes.bias_attrs["port_id"] + target_point = ONNXTargetPoint(TargetType.LAYER, node_name, port_id) + else: + bias_port_id = node.metatype.bias_port_id + target_point = ONNXTargetPoint(TargetType.LAYER, node.node_name, bias_port_id) return ONNXInitializerUpdateCommand(target_point, bias_value) diff --git a/src/nncf/quantization/algorithms/fast_bias_correction/algorithm.py b/src/nncf/quantization/algorithms/fast_bias_correction/algorithm.py index b020c76a5e3..286ee423c92 100644 --- a/src/nncf/quantization/algorithms/fast_bias_correction/algorithm.py +++ b/src/nncf/quantization/algorithms/fast_bias_correction/algorithm.py @@ -175,13 +175,11 @@ def apply( output_channel_axis = node.metatype.output_channel_axis input_channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port_id, input_shape) - if bias_value.ndim > 1: - # Make index positive - output_channel_axis = range(bias_value.ndim)[output_channel_axis] - input_channel_axis = range(bias_value.ndim)[input_channel_axis] + input_blob = self._backend_entity.create_input_data( input_shape, input_fp, sub_input_name, input_channel_axis ) + bias_shift = self._get_bias_shift( model=extracted_model, input_blob=input_blob, diff --git a/src/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py b/src/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py index 18925b6d0fa..2447d5b8509 100644 --- a/src/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py +++ b/src/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py @@ -67,6 +67,7 @@ def get_sub_input_output_names(subgraph: onnx.ModelProto) -> tuple[str, str]: def create_input_data( shape: tuple[int], data: list[Tensor], input_name: str, channel_axis: int ) -> dict[str, np.array]: + channel_axis = range(len(shape))[channel_axis] blob = np.zeros(shape, dtype=data[0].data.dtype) for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])): index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim)) diff --git a/src/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py b/src/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py index d0b90296d96..0fc0501056b 100644 --- a/src/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py +++ b/src/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py @@ -72,6 +72,7 @@ def get_sub_input_output_names(subgraph: ov.Model) -> tuple[str, str]: def create_input_data( shape: tuple[int], data: list[Tensor], input_name: str, channel_axis: int ) -> dict[str, np.ndarray]: + channel_axis = range(len(shape))[channel_axis] blob = np.zeros(shape, dtype=data[0].data.dtype) for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])): index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim)) diff --git a/src/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py b/src/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py index 847922662c0..1a77ab88365 100644 --- a/src/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py +++ b/src/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py @@ -78,6 +78,7 @@ def get_sub_input_output_names(subgraph: NNCFNetwork) -> tuple[Optional[str], Op @staticmethod def create_input_data(shape: tuple[int], data: list[Tensor], input_name: str, channel_axis: int) -> torch.Tensor: + channel_axis = range(len(shape))[channel_axis] blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device) for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])): index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim)) diff --git a/src/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py b/src/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py index 108bf904eff..c3518bc50e6 100644 --- a/src/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py +++ b/src/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py @@ -67,6 +67,7 @@ def get_sub_input_output_names(subgraph: torch.fx.GraphModule) -> tuple[Optional @staticmethod def create_input_data(shape: tuple[int], data: list[Tensor], input_name: str, channel_axis: int) -> torch.Tensor: + channel_axis = range(len(shape))[channel_axis] blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device) for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])): index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim)) diff --git a/tests/post_training/data/ptq_reference_data.yaml b/tests/post_training/data/ptq_reference_data.yaml index 20a5e68b4e5..47f51a685a0 100644 --- a/tests/post_training/data/ptq_reference_data.yaml +++ b/tests/post_training/data/ptq_reference_data.yaml @@ -139,7 +139,7 @@ timm/deit3_small_patch16_224_backend_CUDA_TORCH: timm/deit3_small_patch16_224_backend_FP32: metric_value: 0.81358 timm/deit3_small_patch16_224_backend_ONNX: - metric_value: 0.81116 + metric_value: 0.81156 timm/deit3_small_patch16_224_backend_OV: metric_value: 0.81276 timm/deit3_small_patch16_224_backend_TORCH: