Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 75 additions & 9 deletions backends/arm/_passes/annotate_channels_last_dim_order_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,50 @@
from typing import cast

import torch
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
)
from executorch.backends.arm.tosa_quant_utils import dq_op
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.library import impl, Library

# Define lib with passthrough operators. The operators have no real meaning in edge IR
# except for argument validation and a passthrough output. The operators will be used
# when lowering to TOSA, e.g. a passthrough_to_tosa._transpose will not affect
# the edge IR graph but will be lowered to a TOSA-TRANSPOSE.
lib = Library("passthrough_to_tosa", "DEF")
# For operators that change the rank of the input, such as unsqueeze and squeeze, we may need
# to switch dim_order before the operation. Changing tosa_dim_order is not sufficient
# as we also need to transpose the data into the correct data format.
# By utilizing an edge IR passthrough operator we can keep the edge program in
# channels-first/contiguous and get the desired behavior in the TOSA lowering.
lib.define("_transpose(Tensor self, int[] dim_order) -> Tensor")


@impl(lib, "_transpose")
def _transpose_impl(*args, **kwargs):
    """Edge-IR implementation of the passthrough `_transpose` operator.

    The first positional argument is the input tensor and the second is the
    dim_order list. Only the length of dim_order is checked; the input is
    handed back untouched.
    """
    tensor, dim_order = args[0], args[1]
    # dim_order may describe at most a rank-4 permutation.
    assert len(dim_order) <= 4
    # No data movement happens in edge IR; the input is returned as-is.
    return tensor


class AnnotateChannelsLastDimOrder(ExportPass):
"""
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes.
The annotated tosa_dim_order is used to permute the node's shape such that it
gives a TOSA-compliant shape.
that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts passthrough_to_tosa._transpose
when a transition between 3D and 4D tensors happens.
The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
"""

NHWC_order = (0, 2, 3, 1)
NHWC_inverse_order = (0, 3, 1, 2)
HWCM_order = (2, 3, 0, 1)

def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
"""
returns True for dq and w in the following sequences;
Expand All @@ -49,20 +79,56 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):

return False

def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
    """Insert passthrough_to_tosa._transpose nodes around rank-changing ops.

    * squeeze_copy.dims with a 4D input: a _transpose using
      NHWC_inverse_order is placed in front of the squeeze, and the squeeze
      is rewired to read from it.
    * unsqueeze_copy.default with a 4D output: a _transpose using NHWC_order
      is placed behind the unsqueeze, and every other user of the unsqueeze
      is rewired to read from the new node.

    The "tosa_dim_order" meta entry of the affected nodes is set accordingly.
    """
    graph = graph_module.graph
    for node in graph.nodes:
        if node.op != "call_function":
            continue

        if node.target == exir_ops.edge.aten.squeeze_copy.dims:
            source = node.args[0]
            source_val = source.meta["val"]
            if source_val.dim() != 4:
                continue
            with graph.inserting_before(node):
                transpose = create_node(
                    graph,
                    torch.ops.passthrough_to_tosa._transpose,
                    args=(source, list(self.NHWC_inverse_order)),
                )
                # The inserted node is annotated with the identity order.
                transpose.meta["tosa_dim_order"] = tuple(range(source_val.dim()))
                node.replace_input_with(source, transpose)
            continue

        if node.target == exir_ops.edge.aten.unsqueeze_copy.default:
            if node.meta["val"].dim() != 4:
                continue
            with graph.inserting_after(node):
                transpose = create_node(
                    graph,
                    torch.ops.passthrough_to_tosa._transpose,
                    args=(node, list(self.NHWC_order)),
                )
                transpose.meta["tosa_dim_order"] = self.NHWC_order
                # The unsqueeze itself is annotated with the identity order.
                node.meta["tosa_dim_order"] = (0, 1, 2, 3)
                # Snapshot the consumers first: rewiring mutates node.users.
                consumers = [u for u in node.users if u != transpose]
                for consumer in consumers:
                    consumer.replace_input_with(node, transpose)

def call(self, graph_module: torch.fx.GraphModule):
    """Annotate every node with "tosa_dim_order" and insert TOSA transposes.

    For each node the first fake tensor is inspected:
      * 4D tensors get NHWC_order, or HWCM_order when the node is a weight
        (or its dq node) feeding a depthwise conv2d;
      * all other ranks get the identity order.

    Afterwards passthrough transposes are inserted where the rank changes
    to or from 4D, the module is recompiled, and the base ExportPass is run.

    Returns:
        PassResult wrapping the resulting graph module, always flagged as
        modified.
    """
    for node in graph_module.graph.nodes:
        node_data = get_first_fake_tensor(node).data

        if node_data.dim() == 4:
            dim_order = self.NHWC_order
            if self.is_weight_node_for_depthwise_conv2d(node):
                # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
                # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
                dim_order = self.HWCM_order
        else:
            dim_order = tuple(range(node_data.dim()))
        node.meta["tosa_dim_order"] = dim_order

    # Take care of rank transitions handled by insert_tosa_transposes:
    #   4D (NHWC) input to squeeze_copy.dims
    #   unsqueeze_copy producing a 4D (NHWC) output
    self.insert_tosa_transposes(graph_module)
    graph_module.recompile()
    graph_module = super().call(graph_module).graph_module

    return PassResult(graph_module, True)
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
op_squeeze,
op_sub,
op_sum,
op_transpose,
op_unsqueeze,
op_view,
)
42 changes: 42 additions & 0 deletions backends/arm/operators/op_transpose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class TransposeVisitor(NodeVisitor):
    """
    Node visitor for the _transpose op of the passthrough_to_tosa library,
    which is used when switching between tosa_dim_orders.
    Emits a TOSA TRANSPOSE for each visited node.
    """

    target = "_transpose"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Add a TRANSPOSE operator from inputs[0] to output.

        The permutation is read from inputs[1].special; each entry is taken
        modulo the output rank before being written to the attribute.
        """
        rank = len(output.shape)
        permutation = []
        for axis in inputs[1].special:
            # The modulo maps negative axis values into [0, rank).
            permutation.append(axis % rank)

        transpose_attr = ts.TosaSerializerAttribute()
        transpose_attr.TransposeAttribute(permutation)

        tosa_graph.addOperator(
            TosaOp.Op().TRANSPOSE,
            [inputs[0].name],
            [output.name],
            transpose_attr,
        )
12 changes: 5 additions & 7 deletions backends/arm/test/ops/test_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from executorch.backends.arm.test.tester.arm_tester import ArmTester

from executorch.backends.xnnpack.test.tester.tester import Quantize
from executorch.exir.backend.backend_details import CompileSpec
from parameterized import parameterized


Expand Down Expand Up @@ -77,14 +78,14 @@ def _test_expand_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tupl
)

def _test_expand_ethosu_BI_pipeline(
self, module: torch.nn.Module, test_data: Tuple
self, compile_spec: CompileSpec, module: torch.nn.Module, test_data: Tuple
):
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_u55_compile_spec(),
compile_spec=compile_spec,
)
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
.export()
Expand All @@ -104,17 +105,14 @@ def test_expand_tosa_MI(self, test_input, multiples):
def test_expand_tosa_BI(self, test_input, multiples):
self._test_expand_tosa_BI_pipeline(self.Expand(), (test_input, multiples))

@parameterized.expand(Expand.test_parameters)
def test_expand_u55_BI(self, test_input, multiples):
    """Run the quantized Ethos-U pipeline for Expand with the U55 compile spec."""
    # The helper takes (compile_spec, module, test_data), in that order.
    self._test_expand_ethosu_BI_pipeline(
        common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
    )

@parameterized.expand(Expand.test_parameters)
def test_expand_u85_BI(self, test_input, multiples):
    """Run the quantized Ethos-U pipeline for Expand with the U85 compile spec."""
    # The helper takes (compile_spec, module, test_data), in that order.
    self._test_expand_ethosu_BI_pipeline(
        common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
    )
2 changes: 0 additions & 2 deletions backends/arm/test/ops/test_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,12 @@ def test_repeat_tosa_BI(self, test_input, multiples):
self._test_repeat_tosa_BI_pipeline(self.Repeat(), (test_input, multiples))

@parameterized.expand(Repeat.test_parameters)
def test_repeat_u55_BI(self, test_input, multiples):
    """Run the Ethos-U pipeline for Repeat with the U55 compile spec."""
    self._test_repeat_ethosu_pipeline(
        common.get_u55_compile_spec(), self.Repeat(), (test_input, multiples)
    )

@parameterized.expand(Repeat.test_parameters)
@unittest.expectedFailure # TODO: MLBEDSW-9386
def test_repeat_u85_BI(self, test_input, multiples):
self._test_repeat_ethosu_pipeline(
common.get_u85_compile_spec(), self.Repeat(), (test_input, multiples)
Expand Down
10 changes: 6 additions & 4 deletions backends/arm/test/ops/test_squeeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def forward(self, x: torch.Tensor, dim: int):

class SqueezeDims(torch.nn.Module):
test_parameters: list[tuple[torch.Tensor, tuple[int]]] = [
(torch.randn(1, 1, 5), (0, 1)),
(torch.randn(1, 5, 5, 1), (0, -1)),
(torch.randn(1, 5, 1, 5), (0, -2)),
]
Expand All @@ -47,6 +48,7 @@ def forward(self, x: torch.Tensor, dims: tuple[int]):

class Squeeze(torch.nn.Module):
test_parameters: list[tuple[torch.Tensor]] = [
(torch.randn(1, 1, 5),),
(torch.randn(1, 5, 5, 1),),
(torch.randn(1, 5, 1, 5),),
]
Expand All @@ -64,7 +66,7 @@ def _test_squeeze_tosa_MI_pipeline(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
compile_spec=common.get_tosa_compile_spec(),
)
.export()
.check_count({export_target: 1})
Expand All @@ -86,7 +88,7 @@ def _test_squeeze_tosa_BI_pipeline(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
compile_spec=common.get_tosa_compile_spec(),
)
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
.export()
Expand Down Expand Up @@ -184,7 +186,7 @@ def test_squeeze_dim_u55_BI(self, test_tensor: torch.Tensor, dim: int):
@parameterized.expand(SqueezeDim.test_parameters)
def test_squeeze_dim_u85_BI(self, test_tensor: torch.Tensor, dim: int):
self._test_squeeze_ethosu_BI_pipeline(
common.get_u85_compile_spec(permute_memory_to_nhwc=False),
common.get_u85_compile_spec(permute_memory_to_nhwc=True),
self.SqueezeDim(),
(test_tensor, dim),
"torch.ops.aten.squeeze.dim",
Expand Down Expand Up @@ -214,7 +216,7 @@ def test_squeeze_dims_u55_BI(self, test_tensor: torch.Tensor, dims: tuple[int]):
@parameterized.expand(SqueezeDims.test_parameters)
def test_squeeze_dims_u85_BI(self, test_tensor: torch.Tensor, dims: tuple[int]):
self._test_squeeze_ethosu_BI_pipeline(
common.get_u85_compile_spec(permute_memory_to_nhwc=False),
common.get_u85_compile_spec(),
self.SqueezeDims(),
(test_tensor, dims),
"torch.ops.aten.squeeze.dims",
Expand Down
12 changes: 6 additions & 6 deletions backends/arm/test/ops/test_unsqueeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

class TestSimpleUnsqueeze(unittest.TestCase):
class Unsqueeze(torch.nn.Module):
shapes: list[int | Sequence[int]] = [5, (5, 5), (5, 5), (5, 4, 3)]
shapes: list[int | Sequence[int]] = [5, (5, 5), (5, 4), (5, 4, 3)]
test_parameters: list[tuple[torch.Tensor]] = [(torch.randn(n),) for n in shapes]

def forward(self, x: torch.Tensor, dim):
Expand All @@ -40,7 +40,7 @@ def _test_unsqueeze_tosa_MI_pipeline(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
compile_spec=common.get_tosa_compile_spec(),
)
.export()
.check_count({"torch.ops.aten.unsqueeze.default": 1})
Expand All @@ -59,7 +59,7 @@ def _test_unsqueeze_tosa_BI_pipeline(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
compile_spec=common.get_tosa_compile_spec(),
)
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
.export()
Expand Down Expand Up @@ -102,18 +102,18 @@ def test_unsqueeze_tosa_MI(self, test_tensor: torch.Tensor):
def test_unsqueeze_tosa_BI(self, test_tensor: torch.Tensor):
self._test_unsqueeze_tosa_BI_pipeline(self.Unsqueeze(), (test_tensor, 0))

# NOTE(review): the last parameter set is excluded for U55 - presumably it is
# not supported on that target; confirm.
@parameterized.expand(Unsqueeze.test_parameters[:-1])
def test_unsqueeze_u55_BI(self, test_tensor: torch.Tensor):
    """Run the quantized Ethos-U pipeline for Unsqueeze (dim 0) with the U55 compile spec."""
    self._test_unsqueeze_ethosu_BI_pipeline(
        common.get_u55_compile_spec(),
        self.Unsqueeze(),
        (test_tensor, 0),
    )

@parameterized.expand(Unsqueeze.test_parameters)
def test_unsqueeze_u85_BI(self, test_tensor: torch.Tensor):
    """Run the quantized Ethos-U pipeline for Unsqueeze (dim 0) with the U85 compile spec."""
    self._test_unsqueeze_ethosu_BI_pipeline(
        common.get_u85_compile_spec(),
        self.Unsqueeze(),
        (test_tensor, 0),
    )
2 changes: 1 addition & 1 deletion examples/arm/setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ function setup_vela() {
if [[ ! -e ethos-u-vela ]]; then
git clone https://review.mlplatform.org/ml/ethos-u/ethos-u-vela
repo_dir="${root_dir}/ethos-u-vela"
base_rev=fe0eaa55c5ed319f78c01978f3b40eb11a9bcb38
base_rev=57ce18c89ccc6f6309333dccb24ed30dc68b571f
patch_repo
fi
cd "${root_dir}/ethos-u-vela"
Expand Down
Loading