Skip to content

Commit 8167327

Browse files
committed
Revert "Arm backend: Move rescales from SUM visitor to pass (pytorch#15299)"
This reverts commit c4cd274.
1 parent eb2c876 commit 8167327

File tree

7 files changed

+83
-63
lines changed

7 files changed

+83
-63
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -194,6 +194,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
194194
self.add_pass(ConvertExpandCopyToRepeatPass())
195195
self.add_pass(UnsqueezeBeforeRepeatPass())
196196
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
197+
self.add_pass(DecomposeSumPass())
197198
self.add_pass(DecomposeCumsumPass(exported_program))
198199
self.add_pass(Conv1dUnsqueezePass())
199200
self.add_pass(DecomposeMaxPool2DPass())
@@ -214,11 +215,10 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
214215
self.add_pass(RewriteMatmulPass())
215216
self.add_pass(RewriteUpsamplePass())
216217
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
217-
self.add_pass(InsertRescaleInt32Pass())
218-
self.add_pass(DecomposeSumPass())
219218
self.add_pass(ToTosaMemoryFormatPass(exported_program))
220219
self.add_pass(RemoveNoopPass())
221220
self.add_pass(InsertRescalePass())
221+
self.add_pass(InsertRescaleInt32Pass())
222222

223223
self.validate_constraints_mandatory()
224224
return self._transform(exported_program.graph_module)
@@ -361,6 +361,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
361361

362362
self.add_pass(ConvertMinMaxPass())
363363
self.add_pass(ReplaceInfValues())
364+
self.add_pass(DecomposeSumPass())
364365

365366
if not self.tosa_spec.is_U55_subset:
366367
# Uses where which is not supported on Ethos-U55

backends/arm/_passes/decompose_sum_pass.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@ def call_operator(self, op, args, kwargs, meta):
8383
if not keepdims:
8484
shape = list(meta["val"].size())
8585
input_node = super().call_operator(
86-
view_op, (input_node, shape), {}, meta, updated=True
86+
view_op, (input_node, shape), kwargs, meta, updated=True
8787
)
8888

8989
return input_node

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 1 addition & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
import torch
1111
from executorch.backends.arm._passes.arm_pass import ArmPass
1212
from executorch.backends.arm._passes.arm_pass_utils import create_node, set_node_arg
13-
from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass
1413
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1514
get_output_qparams,
1615
)
@@ -85,11 +84,7 @@ class InsertRescaleInt32Pass(ArmPass):
8584
parameters.
8685
"""
8786

88-
# SUM must be decomposed after this pass to prevent insertion of RESCALE
89-
# nodes between each subsequent SUM node after decomposition. RESCALE nodes
90-
# should only be inserted before and after the SUM node prior to its
91-
# decomposition.
92-
_passes_required_after: Set[Type[ExportPass]] = {DecomposeSumPass}
87+
_passes_required_after: Set[Type[ExportPass]] = set()
9388

9489
included_targets = [
9590
exir_ops.edge.aten.abs.default,
@@ -101,7 +96,6 @@ class InsertRescaleInt32Pass(ArmPass):
10196
exir_ops.edge.aten.maximum.default,
10297
exir_ops.edge.aten.minimum.default,
10398
exir_ops.edge.aten.mul.Tensor,
104-
exir_ops.edge.aten.sum.dim_IntList,
10599
]
106100

107101
def _int32_qargs(self, s):
@@ -144,7 +138,6 @@ def _get_inputs_rescaled_qparams(
144138
}
145139
elif target in [
146140
exir_ops.edge.aten.mul.Tensor,
147-
exir_ops.edge.aten.sum.dim_IntList,
148141
]:
149142
# The input scales do not need to be adjusted for these ops; they
150143
# can remain the same.
@@ -167,7 +160,6 @@ def _get_output_qparams(
167160
exir_ops.edge.aten.abs.default,
168161
exir_ops.edge.aten.maximum.default,
169162
exir_ops.edge.aten.minimum.default,
170-
exir_ops.edge.aten.sum.dim_IntList,
171163
]:
172164
# The op has not altered the scale; the output scale is equal to
173165
# the operands' scales.

backends/arm/operator_support/reduce_sum_support.py

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -29,13 +29,8 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
2929

3030
# U55 case, Vela 4.2.0 (25.02 release)
3131
input_shape = node.all_input_nodes[0].meta["val"].shape
32-
33-
if node.args[1] is None:
34-
# Dim is allowed to be None, which means to sum all dimensions
35-
dim_list = list(range(len(input_shape)))
36-
else:
37-
dim_list = cast(list[int], node.args[1])
38-
dim_list = [dim % len(input_shape) for dim in dim_list]
32+
dim_list = cast(list[int], node.args[1])
33+
dim_list = [dim % len(input_shape) for dim in dim_list]
3934

4035
for dim in dim_list:
4136
if not 1 <= input_shape[dim] <= 65536:

backends/arm/operators/op_sum.py

Lines changed: 62 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77

88
from typing import Any, List
99

10+
import executorch.backends.arm.tosa.quant_utils as tqutils
11+
import executorch.backends.arm.tosa.utils as tutils
1012
import tosa_serializer as ts
1113

1214
from executorch.backends.arm.operators.node_visitor import (
@@ -23,14 +25,69 @@
2325

2426

2527
@register_node_visitor
26-
class SumVisitor(NodeVisitor):
28+
class SumVisitor_INT(NodeVisitor):
2729
target = "aten.sum.dim_IntList"
2830

2931
tosa_specs = [
30-
TosaSpecification.create_from_string("TOSA-1.0+FP"),
3132
TosaSpecification.create_from_string("TOSA-1.0+INT"),
3233
]
3334

35+
def __init__(self, *args):
36+
super().__init__(*args)
37+
38+
def define_node(
39+
self,
40+
node: Node,
41+
tosa_graph: Any,
42+
inputs: List[TosaArg],
43+
output: TosaArg,
44+
) -> None:
45+
validate_num_inputs(self.target, inputs, 3)
46+
validate_same_dtype(self.target, [inputs[0], output], ts)
47+
48+
tensor = inputs[0]
49+
input_shape = list(tensor.shape)
50+
dim = int(inputs[1].number % len(input_shape))
51+
52+
output_shape = input_shape
53+
output_shape[dim] = 1 # Output shape is input shape with dim reduced
54+
55+
# Rescale input to 32 bit
56+
rescaled_inputs, scale = tqutils.insert_rescale_ops_to_int32(
57+
tosa_graph, [tensor], node, self.tosa_spec
58+
)
59+
60+
attr = ts.TosaSerializerAttribute()
61+
attr.ReduceSumAttribute(tensor.dim_order.index(dim))
62+
63+
intermediate = tosa_graph.addIntermediate(
64+
tutils.tosa_shape(output_shape, tensor.dim_order),
65+
dtype=ts.DType.INT32,
66+
)
67+
68+
self._serialize_operator(
69+
node,
70+
tosa_graph,
71+
ts.Op.REDUCE_SUM,
72+
[rescaled_inputs[0].name],
73+
[intermediate.name],
74+
attr,
75+
)
76+
77+
tqutils.insert_rescale_op_to_int8(
78+
tosa_graph, intermediate, scale, node, self.tosa_spec
79+
)
80+
81+
82+
@register_node_visitor
83+
class SumVisitor_FP(SumVisitor_INT):
84+
# inheriting 'target' from INT class
85+
86+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
87+
88+
def __init__(self, *args):
89+
super().__init__(*args)
90+
3491
def define_node(
3592
self,
3693
node: Node,
@@ -45,6 +102,9 @@ def define_node(
45102
input_shape = list(tensor.shape)
46103
dim = int(inputs[1].number % len(input_shape))
47104

105+
output_shape = input_shape
106+
output_shape[dim] = 1 # Output shape is input shape with dim reduced
107+
48108
attr = ts.TosaSerializerAttribute()
49109
attr.ReduceSumAttribute(tensor.dim_order.index(dim))
50110

backends/arm/test/ops/test_sum.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,6 @@ class Sum(torch.nn.Module):
3535
"4d_dim_3_keep": lambda: (torch.rand(1, 2, 3, 4), 3, True),
3636
"4d_dims_keep": lambda: (torch.rand(1, 2, 8, 8), [2, 3, 0], True),
3737
"dim_None": lambda: (torch.rand(10), None, True),
38-
"dim_None_4d_tensor": lambda: (torch.rand(10, 3, 2, 1), None, True),
3938
}
4039

4140
def forward(self, x: torch.Tensor, dim: int, keepdim: bool):

backends/arm/test/passes/test_insert_rescale_i32_pass.py

Lines changed: 14 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -13,11 +13,14 @@
1313
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1414

1515

16-
class MultipleOpsModel(torch.nn.Module):
16+
class NeedsRescaleOps(torch.nn.Module):
1717
"""A module containing ops that require INT32 inputs/outputs."""
1818

1919
input_t = Tuple[torch.Tensor, torch.Tensor]
2020

21+
def __init__(self):
22+
super().__init__()
23+
2124
def forward(self, x, y):
2225
a = x * y
2326
b = torch.maximum(a, y)
@@ -36,41 +39,19 @@ def get_inputs(self, dtype) -> input_t:
3639
else:
3740
raise ValueError("Not a valid input dtype for model")
3841

39-
def get_num_expected_rescales(self):
40-
# "number of op nodes with i8 output" + "number of i8 node inputs"
41-
return 3 + 7
42-
43-
44-
class SumModel(torch.nn.Module):
45-
input_t = Tuple[torch.Tensor]
46-
47-
def forward(self, x):
48-
a = torch.sum(x, 2, keepdim=True) # (1, 2, 1, 4)
49-
b = torch.sum(a, [1, 3], keepdim=True) # (1, 1, 1, 1)
50-
c = torch.sum(b, [0, 2], keepdim=False) # (1, 1)
51-
return c
52-
53-
def get_inputs(self, dtype) -> input_t:
54-
if dtype == torch.float32:
55-
return (torch.rand(1, 2, 3, 4),)
56-
elif dtype == torch.int32:
57-
return (torch.randint(0, 10, (1, 2, 3, 4), dtype=torch.int32),)
58-
else:
59-
raise ValueError("Not a valid input dtype for model")
6042

61-
def get_num_expected_rescales(self):
62-
# Two RESCALE nodes per SUM node
63-
return 6
64-
65-
66-
def _test_model_with_f32_data(model):
43+
def test_insert_rescales():
44+
module = NeedsRescaleOps()
45+
input_t = Tuple[torch.Tensor, torch.Tensor]
6746
ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"}
6847
ops_after = {
69-
"executorch_exir_dialects_backend__ops_tosa_RESCALE_default": model.get_num_expected_rescales(),
48+
# "number of op nodes with i8 output" + "number of i8 node inputs"
49+
"executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 3
50+
+ 7,
7051
}
71-
pipeline = PassPipeline[model.input_t](
72-
model,
73-
model.get_inputs(torch.float32),
52+
pipeline = PassPipeline[input_t](
53+
module,
54+
module.get_inputs(torch.float32),
7455
quantize=True,
7556
ops_not_before_pass=ops_not_before,
7657
ops_after_pass=ops_after,
@@ -80,16 +61,8 @@ def _test_model_with_f32_data(model):
8061
pipeline.run()
8162

8263

83-
def test_insert_rescales_sum_model():
84-
_test_model_with_f32_data(SumModel())
85-
86-
87-
def test_insert_rescales_multiple_ops_model():
88-
_test_model_with_f32_data(MultipleOpsModel())
89-
90-
9164
def test_dont_insert_rescales():
92-
module = MultipleOpsModel()
65+
module = NeedsRescaleOps()
9366
input_t = Tuple[torch.Tensor, torch.Tensor]
9467
ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"}
9568
# All inputs are already i32. Rescales should not be added.

0 commit comments

Comments
 (0)