Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/nncf/onnx/graph/metatypes/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@
OPERATIONS_WITH_BIAS_REDUCED = [
onnx_metatypes.ONNXConvolutionMetatype,
onnx_metatypes.ONNXGemmMetatype,
# TODO: Need to add MatMul with the separate bias support (CVS-135433)
onnx_metatypes.ONNXMatMulMetatype,
]

OPERATIONS_WITH_BIAS = [
Expand Down
1 change: 0 additions & 1 deletion src/nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ class ONNXMatMulMetatype(ONNXOpWithWeightsMetatype):
op_names = ["MatMul"]
hw_config_names = [HWConfigOpName.MATMUL]
weight_channel_axis = -1 # For port_id=1
bias_port_id = 2
possible_weight_ports = [0, 1]
output_channel_axis = -1

Expand Down
49 changes: 46 additions & 3 deletions src/nncf/onnx/graph/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from nncf.onnx.graph.metatypes.groups import OPERATIONS_WITH_BIAS
from nncf.onnx.graph.metatypes.groups import POSSIBLE_WEIGHT_LAYER_METATYPES
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXGemmMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXMatMulMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpWithWeightsMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import get_metatype
Expand All @@ -39,6 +40,7 @@
from nncf.onnx.graph.onnx_helper import get_output_port_id_for_node_before_output
from nncf.onnx.graph.onnx_helper import get_parents_node_mapping
from nncf.onnx.graph.onnx_helper import is_node_has_shared_weight
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference


class ONNXLayerAttributes(BaseLayerAttributes):
Expand Down Expand Up @@ -186,17 +188,55 @@ def _get_bias_attr(
node: onnx.NodeProto,
model: onnx.ModelProto,
parents_node_mapping: dict[str, onnx.NodeProto],
children_node_mapping: dict[str, onnx.NodeProto],
) -> dict[str, str]:
"""
Returns bias tensor attributes.

:param node: ONNX node.
:param model: ONNX model.
:param parents_node_mapping: Mapping from edge name to node which outputs this edge.
:param children_node_mapping: Mapping from edge name to the nodes which consume this edge as an input.
:return: Bias tensor attributes.
"""
bias_attrs = {}
metatype = get_metatype(model, node)

if metatype == ONNXMatMulMetatype:
weight_port_ids = _get_weight_port_ids(node, model, parents_node_mapping)

if not weight_port_ids:
# `node` is a MatMul without weights, so return empty attributes
return {}

# Retrieve all nodes that consume the output of the MatMul operation.
# The MatMul operation has only one output.
y = node.output[0]
consumers = children_node_mapping[y]

if len(consumers) != 1 or consumers[0].op_type != "Add":
return {}

# At this point the `MatMul` output has exactly one consumer, and that
# consumer is an `Add` node.
add_node = consumers[0]

# Find the input of `add_node` that is not equal to `y`.
tensor_name = None
port_id = None
for i, name in enumerate(add_node.input):
if name != y:
tensor_name = name
port_id = i
break

# Ensure that `tensor_name` is the output of a `Constant` node or an initializer.
initializer = {x.name: x for x in model.graph.initializer}
if tensor_name in initializer or parents_node_mapping[tensor_name].op_type == "Constant":
return {"node": add_node.name, "name": tensor_name, "port_id": port_id}
else:
return {}

bias_attrs = {}
if _is_node_with_bias(node, model):
bias_tensor_port_id = get_bias_tensor_port_id(metatype)
bias_edge_name = get_tensor_edge_name(model, node, bias_tensor_port_id, parents_node_mapping)
Expand Down Expand Up @@ -347,7 +387,9 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph:
:return: NNCFGraph.
"""
onnx_model = GraphConverter._replace_empty_node_name(onnx_model)
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
# onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
onnx_model = SymbolicShapeInference.infer_shapes(onnx_model)

edge_info_mapping = get_edge_info_mapping(onnx_model)
children_node_mapping = get_children_node_mapping(onnx_model)
parents_node_mapping = get_parents_node_mapping(onnx_model)
Expand All @@ -358,7 +400,8 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph:
is_shared = None
weight_attrs = {}
node_attrs = _get_node_attrs(node, onnx_model)
bias_attrs = _get_bias_attr(node, onnx_model, parents_node_mapping)
bias_attrs = _get_bias_attr(node, onnx_model, parents_node_mapping, children_node_mapping)

if weight_port_ids: # If node has weight
weight_edge_names = []
for weight_port_id in weight_port_ids:
Expand Down
1 change: 1 addition & 0 deletions src/nncf/onnx/graph/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def get_bias_value(node_with_bias: NNCFNode, model: onnx.ModelProto) -> np.ndarr
:return: The bias value that is applied to the output tensor of the node's operation.
"""
assert node_with_bias.layer_attributes.has_bias()
# TODO(andrey-churkin): Support Add + Constant case
bias_name = node_with_bias.layer_attributes.bias_attrs["name"]
return get_tensor_value(model, bias_name)

Expand Down
9 changes: 7 additions & 2 deletions src/nncf/onnx/graph/transformations/command_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,13 @@ def create_bias_correction_command(node: NNCFNode, bias_value: np.ndarray) -> ON
:param bias_value: The new bias value that will be set.
:return: The `ONNXInitializerUpdateCommand` command to update bias.
"""
bias_port_id = node.metatype.bias_port_id
target_point = ONNXTargetPoint(TargetType.LAYER, node.node_name, bias_port_id)
node_name = node.layer_attributes.bias_attrs.get("node")
if node_name:
port_id = node.layer_attributes.bias_attrs["port_id"]
target_point = ONNXTargetPoint(TargetType.LAYER, node_name, port_id)
else:
bias_port_id = node.metatype.bias_port_id
target_point = ONNXTargetPoint(TargetType.LAYER, node.node_name, bias_port_id)
return ONNXInitializerUpdateCommand(target_point, bias_value)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,11 @@ def apply(

output_channel_axis = node.metatype.output_channel_axis
input_channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port_id, input_shape)
if bias_value.ndim > 1:
# Make index positive
output_channel_axis = range(bias_value.ndim)[output_channel_axis]
input_channel_axis = range(bias_value.ndim)[input_channel_axis]

input_blob = self._backend_entity.create_input_data(
input_shape, input_fp, sub_input_name, input_channel_axis
)

bias_shift = self._get_bias_shift(
model=extracted_model,
input_blob=input_blob,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def get_sub_input_output_names(subgraph: onnx.ModelProto) -> tuple[str, str]:
def create_input_data(
shape: tuple[int], data: list[Tensor], input_name: str, channel_axis: int
) -> dict[str, np.array]:
channel_axis = range(len(shape))[channel_axis]
blob = np.zeros(shape, dtype=data[0].data.dtype)
for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])):
index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def get_sub_input_output_names(subgraph: ov.Model) -> tuple[str, str]:
def create_input_data(
shape: tuple[int], data: list[Tensor], input_name: str, channel_axis: int
) -> dict[str, np.ndarray]:
channel_axis = range(len(shape))[channel_axis]
blob = np.zeros(shape, dtype=data[0].data.dtype)
for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])):
index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def get_sub_input_output_names(subgraph: NNCFNetwork) -> tuple[Optional[str], Op

@staticmethod
def create_input_data(shape: tuple[int], data: list[Tensor], input_name: str, channel_axis: int) -> torch.Tensor:
channel_axis = range(len(shape))[channel_axis]
blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device)
for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])):
index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def get_sub_input_output_names(subgraph: torch.fx.GraphModule) -> tuple[Optional

@staticmethod
def create_input_data(shape: tuple[int], data: list[Tensor], input_name: str, channel_axis: int) -> torch.Tensor:
channel_axis = range(len(shape))[channel_axis]
blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device)
for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])):
index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim))
Expand Down
2 changes: 1 addition & 1 deletion tests/post_training/data/ptq_reference_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ timm/deit3_small_patch16_224_backend_CUDA_TORCH:
timm/deit3_small_patch16_224_backend_FP32:
metric_value: 0.81358
timm/deit3_small_patch16_224_backend_ONNX:
metric_value: 0.81116
metric_value: 0.81156
timm/deit3_small_patch16_224_backend_OV:
metric_value: 0.81276
timm/deit3_small_patch16_224_backend_TORCH:
Expand Down
Loading