38 changes: 30 additions & 8 deletions backends/nxp/backend/edge_helper.py
@@ -22,14 +22,6 @@
]


def _is_dequantize(node_: Node) -> bool:
return node_.op == "call_function" and node_.target in DEQUANTIZE_OPERATORS


def _is_quantize(node_: Node) -> bool:
return node_.op == "call_function" and node_.target in QUANTIZE_OPERATORS


def input_tensor(node: Node, input_index: int) -> torch.Tensor:
if len(node.all_input_nodes) <= input_index:
raise IndexError
@@ -103,3 +95,33 @@ def try_get_tensor_constant_from_node(
return None
attr_itr = getattr(attr_itr, atom)
return attr_itr


def _is_dequantize(node_: Node) -> bool:
return node_.op == "call_function" and node_.target in [
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
]


def _is_quantize(node_: Node) -> bool:
return node_.op == "call_function" and node_.target in [
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_channel.default,
]


def previous_non_qdq_node(node: Node, input_index: int = 0) -> Node | None:
"""Return the first node which is not a `quantize` or `dequantize`, found by traversing the graph backwards
starting with the `node.args[input_index]`,
"""
current_node = node.args[input_index]
while True:
if _is_quantize(current_node) or _is_dequantize(current_node):
current_node = current_node.args[0]
else:
return current_node
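
As a side note on how `previous_non_qdq_node` behaves: below is a minimal, self-contained sketch of the same backwards traversal on a toy stand-in for `torch.fx.Node`. The `FakeNode` class, the name-based Q/DQ check, and the hand-built conv → quantize → dequantize chain are illustrative assumptions, not part of this PR.

```python
# Illustrative stand-in for torch.fx.Node, just to show the traversal.
class FakeNode:
    def __init__(self, name, op, target, args=()):
        self.name = name
        self.op = op
        self.target = target
        self.args = args


def _looks_like_qdq(node) -> bool:
    # Assumption: Q/DQ nodes are recognised by target name here; the real
    # helpers compare against the `quantized_decomposed` operator overloads.
    return node.op == "call_function" and "quantize" in str(node.target)


def previous_non_qdq_node_sketch(node, input_index=0):
    # Start from the node's input and step over consecutive Q/DQ nodes.
    current = node.args[input_index]
    while _looks_like_qdq(current):
        current = current.args[0]
    return current


# conv -> quantize -> dequantize: starting from `dq`, the traversal steps
# over the quantize node and returns the producing compute op `conv`.
conv = FakeNode("conv", "call_function", "aten.convolution.default")
q = FakeNode("q", "call_function", "quantize_per_tensor.default", (conv,))
dq = FakeNode("dq", "call_function", "dequantize_per_tensor.default", (q,))
assert previous_non_qdq_node_sketch(dq) is conv
```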
@@ -8,8 +8,10 @@
from executorch.backends.nxp.backend.custom_delegation_options import (
CustomDelegationOptions,
)
from executorch.backends.nxp.backend.edge_helper import previous_non_qdq_node
from executorch.backends.nxp.backend.ir.converter.conversion import translator
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
apply_permutation_to,
create_channels_first_to_channels_last_permutation,
)
from executorch.backends.nxp.backend.ir.converter.node_converter import (
@@ -23,6 +25,7 @@
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
from torch.fx import Node
from torch.fx.passes.infra.partitioner import Partition
from torch.nn import Parameter


@@ -85,10 +88,6 @@ def _is_supported_on_target(

dim = CatConverter._get_normalized_dim(node)

# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1491
if dim == 0:
return False

# Neutron requires the channels to be a multiple of `num_macs`. The channels could either be the second or the
# last dimension, depending on the formats of the node.
if node.meta[NXP_NODE_FORMAT].is_channels_first():
@@ -151,6 +150,46 @@ def _is_supported_in_IR(

return True

@classmethod
def supports_partitioning_result(
cls,
node: Node,
partition_list: list[Partition],
custom_delegation_options: CustomDelegationOptions,
):
        # There is a bug in the NeutronConverter: if all input dimensions before the one referenced by `dim`
        # are equal to `1`, the `Concat` is not delegated. This only happens when the inputs to the `Concat`
        # are model inputs, and not outputs of other operators.
cat_partition = [p for p in partition_list if node in p.nodes][0]
cat_inputs = map(previous_non_qdq_node, node.args[0])

if not all(
input_.op == "call_function" and input_ in cat_partition.nodes
for input_ in cat_inputs
):
# Some inputs of the `cat` are NOT in the same partition as `cat`.
dim = CatConverter._get_normalized_dim(node)
input_shapes = [list(n.meta["val"].shape) for n in node.args[0]]
if node.meta[NXP_NODE_FORMAT].is_channels_first():
# Transform the shapes to channels last.
to_nhwc_perm = create_channels_first_to_channels_last_permutation(
len(node.meta["val"].shape), True
)
input_shapes = [
apply_permutation_to(shape, to_nhwc_perm) for shape in input_shapes
]

# Transform the `dim` to refer to a channels last dimension.
dim = to_nhwc_perm.index(dim)

for input_shape in input_shapes:
if not any(d != 1 for d in input_shape[:dim]):
# Do not delegate if there are no "non-1" dimensions in the shape before the `dim` dimension.
return False

return True
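
For illustration, the delegation guard in `supports_partitioning_result` above reduces to a shape check once the shapes are expressed in the layout Neutron sees, and it only runs when some `cat` inputs fall outside the `cat`'s partition. A minimal sketch of that check in isolation, assuming a plain 4D NCHW → NHWC permutation instead of the `translator` helpers, with made-up example shapes:

```python
# Sketch of the delegation guard in isolation. Assumption: for 4D tensors the
# channels-first -> channels-last permutation is plain NCHW -> NHWC.
def _leading_dims_all_one(shape: list[int], dim: int) -> bool:
    # True when every dimension before `dim` equals 1 (for `dim == 0` the
    # slice is empty, which also counts) - the case that triggers the
    # NeutronConverter issue described in the comment above.
    return not any(d != 1 for d in shape[:dim])


def would_delegate(input_shapes: list[list[int]], dim: int, channels_first: bool) -> bool:
    if channels_first:
        nhwc_perm = [0, 2, 3, 1]  # assumed 4D NCHW -> NHWC permutation
        input_shapes = [[s[i] for i in nhwc_perm] for s in input_shapes]
        dim = nhwc_perm.index(dim)  # express `dim` in the channels-last layout
    # Reject delegation if any input has only 1s before the concat dimension.
    return not any(_leading_dims_all_one(s, dim) for s in input_shapes)


# (1, 1, 8, 8) concatenated on dim=1: every dimension before `dim` is 1 -> rejected.
assert not would_delegate([[1, 1, 8, 8], [1, 1, 8, 8]], dim=1, channels_first=False)
# (8, 8, 8, 8) concatenated on dim=1: dimension 0 is 8 -> the issue does not apply.
assert would_delegate([[8, 8, 8, 8], [8, 8, 8, 8]], dim=1, channels_first=False)
```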

def convert(self, node: Node):
"""Convert the 'aten.cat' operator to TFLite 'Concatenation'."""
self.assert_convertible(node)
@@ -44,6 +44,18 @@ def forward(self, *inputs: torch.Tensor):
return torch.cat(list(inputs), self.dim)


class AddCatModule(torch.nn.Module):

def __init__(self, dim: int):
super().__init__()
self.dim = dim

def forward(self, *inputs: torch.Tensor):
inputs = [input_ + input_ for input_ in inputs]

return torch.cat(list(inputs), self.dim)


class CatConvModule(torch.nn.Module):

def __init__(self, dim: int, channels: int = 4):
@@ -73,7 +85,7 @@ def forward(self, *inputs: torch.Tensor):
],
)
def test_cat__same_shapes(dim, num_inputs, rank, mocker):
input_shape = tuple([2, 8, 8, 8, 8][-rank:])
input_shape = tuple([8, 8, 8, 8][:rank])

converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

@@ -134,11 +146,23 @@ def test_cat__channels_first__same_shapes(dim, num_inputs, mocker):
)


@pytest.mark.parametrize("dim", [0, -4])
@pytest.mark.parametrize("num_inputs", [2])
def test_cat__unsupported_dim__imxrt700(dim, num_inputs):
input_shape = (2, 8, 6, 8)

@pytest.mark.parametrize(
"dim, input_shape",
[
pytest.param(0, (1, 8, 8, 8), id="axis = 0"),
pytest.param(0, (8, 8, 8, 8), id="axis = 0, no `1s` in the shape."),
pytest.param(-4, (1, 8, 8, 8), id="axis = -4"),
pytest.param(1, (1, 1, 8, 8), id="axis = 1"),
pytest.param(-3, (1, 1, 8, 8), id="axis = -3"),
pytest.param(2, (1, 1, 1, 8), id="axis = 2"),
pytest.param(-2, (1, 1, 1, 8), id="axis = -2"),
],
)
def test_cat__unsupported__imxrt700(dim, input_shape):
"""This test is conjoined with the one below (`test_cat__context_dependent__imxrt700`).
In this case, the inputs of the `cat` are NOT compute ops, so the `cat` is NOT delegated.
"""
num_inputs = 2
quantized_program = to_quantized_edge_program(
CatModule(dim), [input_shape] * num_inputs, target="imxrt700"
).exported_program()
@@ -152,6 +176,32 @@ def test_cat__unsupported_dim__imxrt700(dim, num_inputs):
)


@pytest.mark.parametrize(
"dim, input_shape",
[
pytest.param(0, (1, 8, 8, 8), id="axis = 0"),
pytest.param(0, (8, 8, 8, 8), id="axis = 0, no `1s` in the shape."),
pytest.param(-4, (1, 8, 8, 8), id="axis = -4"),
pytest.param(1, (1, 1, 8, 8), id="axis = 1"),
pytest.param(-3, (1, 1, 8, 8), id="axis = -3"),
pytest.param(2, (1, 1, 1, 8), id="axis = 2"),
pytest.param(-2, (1, 1, 1, 8), id="axis = -2"),
],
)
def test_cat__context_dependent__imxrt700(dim, input_shape):
"""This test is conjoined with the one above (`test_cat__unsupported__imxrt700`).
In this case, the inputs of the `cat` are compute ops, so the `cat` is delegated.
"""
num_inputs = 2
ep = to_quantized_edge_program(
AddCatModule(dim), [input_shape] * num_inputs, target="imxrt700"
).exported_program()

# Make sure the `Cat` was delegated.
assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.cat.default])
assert any("lowered_module" in node.name for node in ep.graph.nodes)


@pytest.mark.parametrize(
"rank, num_inputs, dim",
[