5 changes: 2 additions & 3 deletions backends/arm/_passes/arm_pass_manager.py
@@ -194,7 +194,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
self.add_pass(DecomposeSumPass())
self.add_pass(DecomposeCumsumPass(exported_program))
self.add_pass(Conv1dUnsqueezePass())
self.add_pass(DecomposeMaxPool2DPass())
@@ -215,10 +214,11 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(RewriteMatmulPass())
self.add_pass(RewriteUpsamplePass())
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
self.add_pass(InsertRescaleInt32Pass())
self.add_pass(DecomposeSumPass())
self.add_pass(ToTosaMemoryFormatPass(exported_program))
self.add_pass(RemoveNoopPass())
self.add_pass(InsertRescalePass())
self.add_pass(InsertRescaleInt32Pass())

self.validate_constraints_mandatory()
return self._transform(exported_program.graph_module)
@@ -361,7 +361,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):

self.add_pass(ConvertMinMaxPass())
self.add_pass(ReplaceInfValues())
self.add_pass(DecomposeSumPass())

if not self.tosa_spec.is_U55_subset:
# Uses where which is not supported on Ethos-U55
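A minimal functional sketch of the decomposition that motivates this reordering (assumed helper name `decompose_sum`; not the actual DecomposeSumPass, which rewrites graph nodes): a multi-dim sum becomes a cascade of single-dim sums plus a final reshape when keepdim is False. Running RESCALE insertion before this decomposition means only the original SUM node gets wrapped, rather than every intermediate sum.

```python
import torch


def decompose_sum(x: torch.Tensor, dims: list[int], keepdim: bool) -> torch.Tensor:
    dims = sorted(d % x.dim() for d in dims)
    for d in dims:
        x = torch.sum(x, d, keepdim=True)  # one single-dim reduce per entry
    if not keepdim:
        # Drop the reduced dimensions, mirroring the view inserted by the pass.
        x = x.reshape([s for i, s in enumerate(x.shape) if i not in dims])
    return x


t = torch.rand(1, 2, 3, 4)
assert torch.allclose(decompose_sum(t, [1, 3], True), t.sum(dim=[1, 3], keepdim=True))
assert decompose_sum(t, [1, 3], False).shape == t.sum(dim=[1, 3]).shape
```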
2 changes: 1 addition & 1 deletion backends/arm/_passes/decompose_sum_pass.py
@@ -83,7 +83,7 @@ def call_operator(self, op, args, kwargs, meta):
if not keepdims:
shape = list(meta["val"].size())
input_node = super().call_operator(
view_op, (input_node, shape), kwargs, meta, updated=True
view_op, (input_node, shape), {}, meta, updated=True
)

return input_node
10 changes: 9 additions & 1 deletion backends/arm/_passes/insert_rescales_pass.py
@@ -10,6 +10,7 @@
import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node, set_node_arg
from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_output_qparams,
)
@@ -84,7 +85,11 @@ class InsertRescaleInt32Pass(ArmPass):
parameters.
"""

_passes_required_after: Set[Type[ExportPass]] = set()
# SUM must be decomposed after this pass to prevent insertion of RESCALE
# nodes between each subsequent SUM node after decomposition. RESCALE nodes
# should only be inserted before and after the SUM node prior to its
# decomposition.
_passes_required_after: Set[Type[ExportPass]] = {DecomposeSumPass}

included_targets = [
exir_ops.edge.aten.abs.default,
@@ -96,6 +101,7 @@ class InsertRescaleInt32Pass(ArmPass):
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
]

def _int32_qargs(self, s):
@@ -138,6 +144,7 @@ def _get_inputs_rescaled_qparams(
}
elif target in [
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
]:
# The input scales do not need to be adjusted for these ops; they
# can remain the same.
@@ -160,6 +167,7 @@ def _get_output_qparams(
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.sum.dim_IntList,
]:
# The op has not altered the scale; the output scale is equal to
# the operands' scales.
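A numeric sketch of why `sum.dim_IntList` can keep its operands' scale in the branches above (example scale and shapes assumed): for per-tensor quantized data `q` with scale `s`, `sum(q) * s == sum(q * s)`, so the INT32 accumulation reuses the input scale and only the final RESCALE back to int8 applies the output qparams.

```python
import torch

s = 0.05  # assumed example scale
q = torch.randint(-128, 128, (1, 2, 3, 4), dtype=torch.int8)
acc = q.sum(dim=[1, 3], keepdim=True, dtype=torch.int32)  # INT32 SUM, same scale
ref = (q.to(torch.float32) * s).sum(dim=[1, 3], keepdim=True)
assert torch.allclose(acc.to(torch.float32) * s, ref, atol=1e-4)
```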
9 changes: 7 additions & 2 deletions backends/arm/operator_support/reduce_sum_support.py
@@ -29,8 +29,13 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

# U55 case, Vela 4.2.0 (25.02 release)
input_shape = node.all_input_nodes[0].meta["val"].shape
dim_list = cast(list[int], node.args[1])
dim_list = [dim % len(input_shape) for dim in dim_list]

if node.args[1] is None:
# Dim is allowed to be None, which means to sum all dimensions
dim_list = list(range(len(input_shape)))
else:
dim_list = cast(list[int], node.args[1])
dim_list = [dim % len(input_shape) for dim in dim_list]

for dim in dim_list:
if not 1 <= input_shape[dim] <= 65536:
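The None branch above follows PyTorch's semantics for `aten.sum.dim_IntList`, where `dim=None` reduces over every dimension; a quick check of the equivalence the support check relies on (shapes are illustrative):

```python
import torch

x = torch.rand(10, 3, 2, 1)
all_dims = list(range(x.dim()))  # [0, 1, 2, 3]
assert torch.allclose(torch.sum(x, dim=None, keepdim=True),
                      torch.sum(x, dim=all_dims, keepdim=True))
```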
66 changes: 3 additions & 63 deletions backends/arm/operators/op_sum.py
@@ -7,8 +7,6 @@

from typing import Any, List

import executorch.backends.arm.tosa.quant_utils as tqutils
import executorch.backends.arm.tosa.utils as tutils
import serializer.tosa_serializer as ts

from executorch.backends.arm.operators.node_visitor import (
@@ -25,69 +23,14 @@


@register_node_visitor
class SumVisitor_INT(NodeVisitor):
class SumVisitor(NodeVisitor):
target = "aten.sum.dim_IntList"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+FP"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
validate_num_inputs(self.target, inputs, 3)
validate_same_dtype(self.target, [inputs[0], output], ts)

tensor = inputs[0]
input_shape = list(tensor.shape)
dim = int(inputs[1].number % len(input_shape))

output_shape = input_shape
output_shape[dim] = 1 # Output shape is input shape with dim reduced

# Rescale input to 32 bit
rescaled_inputs, scale = tqutils.insert_rescale_ops_to_int32(
tosa_graph, [tensor], node, self.tosa_spec
)

attr = ts.TosaSerializerAttribute()
attr.ReduceSumAttribute(tensor.dim_order.index(dim))

intermediate = tosa_graph.addIntermediate(
tutils.tosa_shape(output_shape, tensor.dim_order),
dtype=ts.DType.INT32,
)

self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().REDUCE_SUM,
[rescaled_inputs[0].name],
[intermediate.name],
attr,
)

tqutils.insert_rescale_op_to_int8(
tosa_graph, intermediate, scale, node, self.tosa_spec
)


@register_node_visitor
class SumVisitor_FP(SumVisitor_INT):
# inheriting 'target' from INT class

tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
@@ -102,17 +45,14 @@ def define_node(
input_shape = list(tensor.shape)
dim = int(inputs[1].number % len(input_shape))

output_shape = input_shape
output_shape[dim] = 1 # Output shape is input shape with dim reduced

attr = ts.TosaSerializerAttribute()
attr.ReduceSumAttribute(tensor.dim_order.index(dim))

self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().REDUCE_SUM,
[tensor.name],
[inputs[0].name],
[output.name],
attr,
)
1 change: 1 addition & 0 deletions backends/arm/test/ops/test_sum.py
@@ -35,6 +35,7 @@ class Sum(torch.nn.Module):
"4d_dim_3_keep": lambda: (torch.rand(1, 2, 3, 4), 3, True),
"4d_dims_keep": lambda: (torch.rand(1, 2, 8, 8), [2, 3, 0], True),
"dim_None": lambda: (torch.rand(10), None, True),
"dim_None_4d_tensor": lambda: (torch.rand(10, 3, 2, 1), None, True),
}

def forward(self, x: torch.Tensor, dim: int, keepdim: bool):
55 changes: 41 additions & 14 deletions backends/arm/test/passes/test_insert_rescale_i32_pass.py
@@ -13,14 +13,11 @@
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline


class NeedsRescaleOps(torch.nn.Module):
class MultipleOpsModel(torch.nn.Module):
"""A module containing ops that require INT32 inputs/outputs."""

input_t = Tuple[torch.Tensor, torch.Tensor]

def __init__(self):
super().__init__()

def forward(self, x, y):
a = x * y
b = torch.maximum(a, y)
@@ -39,19 +36,41 @@ def get_inputs(self, dtype) -> input_t:
else:
raise ValueError("Not a valid input dtype for model")

def get_num_expected_rescales(self):
# "number of op nodes with i8 output" + "number of i8 node inputs"
return 3 + 7

def test_insert_rescales():
module = NeedsRescaleOps()
input_t = Tuple[torch.Tensor, torch.Tensor]

class SumModel(torch.nn.Module):
input_t = Tuple[torch.Tensor]

def forward(self, x):
a = torch.sum(x, 2, keepdim=True) # (1, 2, 1, 4)
b = torch.sum(a, [1, 3], keepdim=True) # (1, 1, 1, 1)
c = torch.sum(b, [0, 2], keepdim=False) # (1, 1)
return c

def get_inputs(self, dtype) -> input_t:
if dtype == torch.float32:
return (torch.rand(1, 2, 3, 4),)
elif dtype == torch.int32:
return (torch.randint(0, 10, (1, 2, 3, 4), dtype=torch.int32),)
else:
raise ValueError("Not a valid input dtype for model")

def get_num_expected_rescales(self):
# Two RESCALE nodes per SUM node
return 6


def _test_model_with_f32_data(model):
ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"}
ops_after = {
# "number of op nodes with i8 output" + "number of i8 node inputs"
"executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 3
+ 7,
"executorch_exir_dialects_backend__ops_tosa_RESCALE_default": model.get_num_expected_rescales(),
}
pipeline = PassPipeline[input_t](
module,
module.get_inputs(torch.float32),
pipeline = PassPipeline[model.input_t](
model,
model.get_inputs(torch.float32),
quantize=True,
ops_not_before_pass=ops_not_before,
ops_after_pass=ops_after,
@@ -61,8 +80,16 @@ def test_insert_rescales():
pipeline.run()


def test_insert_rescales_sum_model():
_test_model_with_f32_data(SumModel())


def test_insert_rescales_multiple_ops_model():
_test_model_with_f32_data(MultipleOpsModel())


def test_dont_insert_rescales():
module = NeedsRescaleOps()
module = MultipleOpsModel()
input_t = Tuple[torch.Tensor, torch.Tensor]
ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"}
# All inputs are already i32. Rescales should not be added.
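As a cross-check of the expected counts used in these tests: SumModel has three SUM nodes and the pass inserts one RESCALE on each int8 input and one on each int8 output, giving 3 × 2 = 6, while MultipleOpsModel's comment counts 3 op nodes with int8 outputs plus 7 int8 node inputs, giving 3 + 7 = 10.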