Skip to content

Commit 9faf559

Browse files
[ONNX] Add MatMul for FBC/BC (#3657)
### Changes Add support for the `MatMul` operation in the Fast Bias/Bias Correction algorithm. ### Reason for changes The Fast Bias/Bias Correction algorithm is not applied to the `MatMul`->`Add` subgraph in which one of the inputs to the Add operation is a constant. In this case, the `MatMul` operation is not considered as having a bias, so the algorithm is not applied. ### Related tickets Ref: 135433 ### Tests - test_update_bias_in_matmul_add - text_examples: [build](https://github.com/openvinotoolkit/nncf/actions/runs/18529108467)
1 parent 0379ea1 commit 9faf559

File tree

13 files changed

+115
-11
lines changed

13 files changed

+115
-11
lines changed

src/nncf/onnx/graph/metatypes/groups.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@
139139
OPERATIONS_WITH_BIAS_REDUCED = [
140140
onnx_metatypes.ONNXConvolutionMetatype,
141141
onnx_metatypes.ONNXGemmMetatype,
142-
# TODO: Need to add MatMul with the separate bias support (CVS-135433)
142+
onnx_metatypes.ONNXMatMulMetatype,
143143
]
144144

145145
OPERATIONS_WITH_BIAS = [

src/nncf/onnx/graph/metatypes/onnx_metatypes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,6 @@ class ONNXMatMulMetatype(ONNXOpWithWeightsMetatype):
143143
op_names = ["MatMul"]
144144
hw_config_names = [HWConfigOpName.MATMUL]
145145
weight_channel_axis = -1 # For port_id=1
146-
bias_port_id = 2
147146
possible_weight_ports = [0, 1]
148147
output_channel_axis = -1
149148

src/nncf/onnx/graph/nncf_graph_builder.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from nncf.onnx.graph.metatypes.groups import OPERATIONS_WITH_BIAS
2727
from nncf.onnx.graph.metatypes.groups import POSSIBLE_WEIGHT_LAYER_METATYPES
2828
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXGemmMetatype
29+
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXMatMulMetatype
2930
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpMetatype
3031
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpWithWeightsMetatype
3132
from nncf.onnx.graph.metatypes.onnx_metatypes import get_metatype
@@ -186,17 +187,54 @@ def _get_bias_attr(
186187
node: onnx.NodeProto,
187188
model: onnx.ModelProto,
188189
parents_node_mapping: dict[str, onnx.NodeProto],
190+
children_node_mapping: dict[str, onnx.NodeProto],
189191
) -> dict[str, str]:
190192
"""
191193
Returns bias tensor attributes.
192194
193195
:param node: ONNX node.
194196
:param model: ONNX model.
195197
:param parents_node_mapping: Mapping from edge name to node which outputs this edge.
198+
:param children_node_mapping: mapping from edge name to nodes which consume this edge as an input.
196199
:return: Bias tensor attributes.
197200
"""
198-
bias_attrs = {}
199201
metatype = get_metatype(model, node)
202+
203+
if metatype == ONNXMatMulMetatype:
204+
weight_port_ids = _get_weight_port_ids(node, model, parents_node_mapping)
205+
206+
if not weight_port_ids:
207+
# `node` is a MatMul without weights, so return empty attributes
208+
return {}
209+
210+
# Retrieve all nodes that consume the output of the MatMul operation.
211+
# The MatMul operation has only one output.
212+
y = node.output[0]
213+
consumers = children_node_mapping[y]
214+
215+
if len(consumers) != 1 or consumers[0].op_type != "Add":
216+
return {}
217+
218+
# Here, we are certain that after a `MatMul` operation, there is only
219+
# the `Add` operation.
220+
add_node = consumers[0]
221+
222+
# Find the input of `add_node` that is not equal to `y`.
223+
tensor_name = None
224+
port_id = None
225+
for i, name in enumerate(add_node.input):
226+
if name != y:
227+
tensor_name = name
228+
port_id = i
229+
break
230+
231+
# Ensure that `tensor_name` is the output of a `Constant` node or an initializer.
232+
initializer = {x.name: x for x in model.graph.initializer}
233+
if tensor_name in initializer or parents_node_mapping[tensor_name].op_type == "Constant":
234+
return {"node": add_node.name, "name": tensor_name, "port_id": port_id}
235+
return {}
236+
237+
bias_attrs = {}
200238
if _is_node_with_bias(node, model):
201239
bias_tensor_port_id = get_bias_tensor_port_id(metatype)
202240
bias_edge_name = get_tensor_edge_name(model, node, bias_tensor_port_id, parents_node_mapping)
@@ -348,6 +386,7 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph:
348386
"""
349387
onnx_model = GraphConverter._replace_empty_node_name(onnx_model)
350388
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
389+
351390
edge_info_mapping = get_edge_info_mapping(onnx_model)
352391
children_node_mapping = get_children_node_mapping(onnx_model)
353392
parents_node_mapping = get_parents_node_mapping(onnx_model)
@@ -358,7 +397,8 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph:
358397
is_shared = None
359398
weight_attrs = {}
360399
node_attrs = _get_node_attrs(node, onnx_model)
361-
bias_attrs = _get_bias_attr(node, onnx_model, parents_node_mapping)
400+
bias_attrs = _get_bias_attr(node, onnx_model, parents_node_mapping, children_node_mapping)
401+
362402
if weight_port_ids: # If node has weight
363403
weight_edge_names = []
364404
for weight_port_id in weight_port_ids:

src/nncf/onnx/graph/node_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def get_bias_value(node_with_bias: NNCFNode, model: onnx.ModelProto) -> np.ndarr
4545
:return: The bias value that is applied to the output tensor of the node's operation.
4646
"""
4747
assert node_with_bias.layer_attributes.has_bias()
48+
# TODO(andrey-churkin): Support Add + Constant case
4849
bias_name = node_with_bias.layer_attributes.bias_attrs["name"]
4950
return get_tensor_value(model, bias_name)
5051

src/nncf/onnx/graph/transformations/command_creation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,13 @@ def create_bias_correction_command(node: NNCFNode, bias_value: np.ndarray) -> ON
2929
:param bias_value: The new bias value that will be set.
3030
:return: The `ONNXInitializerUpdateCommand` command to update bias.
3131
"""
32-
bias_port_id = node.metatype.bias_port_id
33-
target_point = ONNXTargetPoint(TargetType.LAYER, node.node_name, bias_port_id)
32+
node_name = node.layer_attributes.bias_attrs.get("node")
33+
if node_name:
34+
port_id = node.layer_attributes.bias_attrs["port_id"]
35+
target_point = ONNXTargetPoint(TargetType.LAYER, node_name, port_id)
36+
else:
37+
bias_port_id = node.metatype.bias_port_id
38+
target_point = ONNXTargetPoint(TargetType.LAYER, node.node_name, bias_port_id)
3439
return ONNXInitializerUpdateCommand(target_point, bias_value)
3540

3641

src/nncf/quantization/algorithms/fast_bias_correction/algorithm.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,11 @@ def apply(
175175

176176
output_channel_axis = node.metatype.output_channel_axis
177177
input_channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port_id, input_shape)
178-
if bias_value.ndim > 1:
179-
# Make index positive
180-
output_channel_axis = range(bias_value.ndim)[output_channel_axis]
181-
input_channel_axis = range(bias_value.ndim)[input_channel_axis]
178+
182179
input_blob = self._backend_entity.create_input_data(
183180
input_shape, input_fp, sub_input_name, input_channel_axis
184181
)
182+
185183
bias_shift = self._get_bias_shift(
186184
model=extracted_model,
187185
input_blob=input_blob,

src/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def get_sub_input_output_names(subgraph: onnx.ModelProto) -> tuple[str, str]:
6767
def create_input_data(
6868
shape: tuple[int], data: list[Tensor], input_name: str, channel_axis: int
6969
) -> dict[str, np.array]:
70+
channel_axis = range(len(shape))[channel_axis]
7071
blob = np.zeros(shape, dtype=data[0].data.dtype)
7172
for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])):
7273
index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim))

src/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def get_sub_input_output_names(subgraph: ov.Model) -> tuple[str, str]:
7272
def create_input_data(
7373
shape: tuple[int], data: list[Tensor], input_name: str, channel_axis: int
7474
) -> dict[str, np.ndarray]:
75+
channel_axis = range(len(shape))[channel_axis]
7576
blob = np.zeros(shape, dtype=data[0].data.dtype)
7677
for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])):
7778
index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim))

src/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def get_sub_input_output_names(subgraph: NNCFNetwork) -> tuple[Optional[str], Op
7878

7979
@staticmethod
8080
def create_input_data(shape: tuple[int], data: list[Tensor], input_name: str, channel_axis: int) -> torch.Tensor:
81+
channel_axis = range(len(shape))[channel_axis]
8182
blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device)
8283
for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])):
8384
index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim))

src/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def get_sub_input_output_names(subgraph: torch.fx.GraphModule) -> tuple[Optional
6767

6868
@staticmethod
6969
def create_input_data(shape: tuple[int], data: list[Tensor], input_name: str, channel_axis: int) -> torch.Tensor:
70+
channel_axis = range(len(shape))[channel_axis]
7071
blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device)
7172
for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])):
7273
index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim))

0 commit comments

Comments
 (0)