Skip to content

Commit b8e0ef9

Browse files
authored
Add cat/stack ops to generic annotator
Differential Revision: D65067790. Pull Request resolved: #6494.
1 parent c438f8d commit b8e0ef9

File tree

10 files changed

+44
-134
lines changed

10 files changed

+44
-134
lines changed

backends/arm/quantizer/arm_quantizer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,6 @@ class ArmQuantizer(Quantizer):
268268
"sub",
269269
"mul",
270270
"mm",
271-
"cat",
272271
"one_to_one",
273272
"generic",
274273
"sum",

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -144,21 +144,10 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
144144
torch.ops.aten.mean.dim,
145145
torch.ops.aten.permute.default,
146146
torch.ops.aten.permute_copy.default,
147-
torch.ops.aten.squeeze.dim,
148-
torch.ops.aten.squeeze.dims,
149-
torch.ops.aten.squeeze.default,
150-
torch.ops.aten.squeeze_copy.dim,
151-
torch.ops.aten.unsqueeze.default,
152-
torch.ops.aten.unsqueeze_copy.default,
153147
# TODO: remove?
154148
torch.ops.aten.adaptive_avg_pool2d.default,
155149
torch.ops.aten.avg_pool2d.default,
156-
torch.ops.aten.view_copy.default,
157-
torch.ops.aten.view.default,
158150
torch.ops.aten.full.default,
159-
torch.ops.aten.slice.Tensor,
160-
torch.ops.aten.split.Tensor,
161-
torch.ops.aten.split_with_sizes.default,
162151
torch.ops.aten.flatten.using_ints,
163152
torch.ops.aten.dropout.default,
164153
operator.getitem,

backends/arm/quantizer/quantization_annotation/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ def decorator(annotator: AnnotatorType):
5151
from . import ( # noqa
5252
adaptive_ang_pool2d_annotator,
5353
add_annotator,
54-
cat_annotator,
5554
conv_annotator,
5655
generic_annotator,
5756
linear_annotator,

backends/arm/quantizer/quantization_annotation/cat_annotator.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

backends/arm/quantizer/quantization_annotation/generic_annotator.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
87
from typing import Callable, List, Optional
98

109
import torch
@@ -24,6 +23,9 @@
2423
# DATA LAYOUT OPS
2524
torch.ops.aten.squeeze.default,
2625
torch.ops.aten.squeeze_copy.default,
26+
torch.ops.aten.squeeze_copy.dim,
27+
torch.ops.aten.squeeze.dim,
28+
torch.ops.aten.squeeze.dims,
2729
torch.ops.aten.unsqueeze.default,
2830
torch.ops.aten.unsqueeze_copy.default,
2931
torch.ops.aten.reshape.default,
@@ -33,19 +35,21 @@
3335
# torch.ops.aten.view_as_complex_copy.default,
3436
# torch.ops.aten.view_as_real.default,
3537
# torch.ops.aten.view_as_real_copy.default,
38+
torch.ops.aten.view.default,
3639
torch.ops.aten.view_copy.default,
3740
torch.ops.aten.select.int,
3841
torch.ops.aten.select_copy.int,
3942
torch.ops.aten.slice.Tensor,
4043
torch.ops.aten.slice_copy.Tensor,
41-
# 'concat' should be handled separately as it has a sequence of inputs and
42-
# makes the implementation unnecessary complicated.
43-
# torch.ops.aten.concat.default,
44+
torch.ops.aten.split.Tensor,
45+
torch.ops.aten.split_with_sizes.default,
4446
torch.ops.aten.transpose.Dimname,
4547
torch.ops.aten.transpose.int,
4648
torch.ops.aten.transpose_copy.int,
4749
torch.ops.aten.tile.default,
4850
torch.ops.aten.flip.default,
51+
torch.ops.aten.cat.default,
52+
torch.ops.aten.stack.default,
4953
]
5054

5155

@@ -66,15 +70,31 @@ def _annotate_generic(
6670
if arm_quantizer_utils.is_annotated(node):
6771
continue
6872

69-
input_node = node.args[0]
73+
input_acts = node.args[0]
74+
75+
# Check to see if there are multiple inputs.
76+
# this allows for stack/cat ops to be annotated
77+
# in a similar way.
78+
has_multi_inputs = isinstance(input_acts, list)
79+
80+
input_act0 = input_acts[0] if has_multi_inputs else input_acts
7081

7182
# Using a non-shared quantization spec here as a SharedQuantizationSpec
7283
# can lead to a recursion.
7384
_annotate_input_qspec_map(
74-
node, input_node, quantization_config.get_input_act_qspec()
85+
node, input_act0, quantization_config.get_input_act_qspec()
7586
)
76-
_annotate_output_qspec(node, SharedQuantizationSpec((input_node, node)))
87+
shared_with_input0_qspec = SharedQuantizationSpec((input_act0, node))
88+
89+
if has_multi_inputs:
90+
# For the rest of the inputs, share qspec with first.
91+
for input_act in input_acts[1:]:
92+
if input_act is not input_act0:
93+
node.meta["quantization_annotation"].input_qspec_map[
94+
input_act
95+
] = shared_with_input0_qspec
7796

97+
_annotate_output_qspec(node, shared_with_input0_qspec)
7898
arm_quantizer_utils.mark_nodes_as_annotated([node])
7999
annotated_partitions.append([node])
80100

backends/arm/test/ops/test_slice.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,9 @@
88
from typing import Tuple
99

1010
import torch
11-
from executorch.backends.arm.quantizer.arm_quantizer import (
12-
ArmQuantizer,
13-
get_symmetric_quantization_config,
14-
)
11+
1512
from executorch.backends.arm.test import common
1613
from executorch.backends.arm.test.tester.arm_tester import ArmTester
17-
from executorch.backends.xnnpack.test.tester.tester import Quantize
1814
from executorch.exir.backend.compile_spec_schema import CompileSpec
1915
from parameterized import parameterized
2016

@@ -59,7 +55,6 @@ def _test_slice_tosa_BI_pipeline(
5955
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor], permute: bool
6056
):
6157

62-
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
6358
(
6459
ArmTester(
6560
module,
@@ -68,7 +63,7 @@ def _test_slice_tosa_BI_pipeline(
6863
permute_memory_to_nhwc=permute
6964
),
7065
)
71-
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
66+
.quantize()
7267
.export()
7368
.check(["torch.ops.aten.slice.Tensor"])
7469
.to_edge()
@@ -84,14 +79,13 @@ def _test_slice_ethos_BI_pipeline(
8479
module: torch.nn.Module,
8580
test_data: Tuple[torch.Tensor],
8681
):
87-
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
8882
(
8983
ArmTester(
9084
module,
9185
example_inputs=test_data,
9286
compile_spec=common.get_u55_compile_spec(),
9387
)
94-
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
88+
.quantize()
9589
.export()
9690
.check(["torch.ops.aten.slice.Tensor"])
9791
.to_edge()

backends/arm/test/ops/test_split.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,9 @@
77
import unittest
88

99
import torch
10-
from executorch.backends.arm.quantizer.arm_quantizer import (
11-
ArmQuantizer,
12-
get_symmetric_quantization_config,
13-
)
10+
1411
from executorch.backends.arm.test import common
1512
from executorch.backends.arm.test.tester.arm_tester import ArmTester
16-
from executorch.backends.xnnpack.test.tester.tester import Quantize
1713
from executorch.exir.backend.compile_spec_schema import CompileSpec
1814
from parameterized import parameterized
1915

@@ -79,14 +75,13 @@ def _test_split_tosa_BI_pipeline(
7975
self, module: torch.nn.Module, test_data: test_data_t
8076
):
8177

82-
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
8378
(
8479
ArmTester(
8580
module,
8681
example_inputs=test_data,
8782
compile_spec=common.get_tosa_compile_spec(),
8883
)
89-
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
84+
.quantize()
9085
.export()
9186
.to_edge()
9287
.partition()
@@ -98,14 +93,13 @@ def _test_split_tosa_BI_pipeline(
9893
def _test_split_ethosu_BI_pipeline(
9994
self, compile_spec: CompileSpec, module: torch.nn.Module, test_data: test_data_t
10095
):
101-
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
10296
(
10397
ArmTester(
10498
module,
10599
example_inputs=test_data,
106100
compile_spec=compile_spec,
107101
)
108-
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
102+
.quantize()
109103
.export()
110104
.check(["torch.ops.aten.split.Tensor"])
111105
.to_edge()

backends/arm/test/ops/test_squeeze.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,9 @@
1313

1414
import torch
1515

16-
from executorch.backends.arm.quantizer.arm_quantizer import (
17-
ArmQuantizer,
18-
get_symmetric_quantization_config,
19-
)
2016
from executorch.backends.arm.test import common
2117
from executorch.backends.arm.test.tester.arm_tester import ArmTester
2218

23-
from executorch.backends.xnnpack.test.tester.tester import Quantize
2419
from executorch.exir.backend.compile_spec_schema import CompileSpec
2520
from parameterized import parameterized
2621

@@ -83,14 +78,13 @@ def _test_squeeze_tosa_BI_pipeline(
8378
test_data: Tuple[torch.Tensor, Optional[tuple[int]]],
8479
export_target: str,
8580
):
86-
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
8781
(
8882
ArmTester(
8983
module,
9084
example_inputs=test_data,
9185
compile_spec=common.get_tosa_compile_spec(),
9286
)
93-
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
87+
.quantize()
9488
.export()
9589
.check_count({export_target: 1})
9690
.to_edge()
@@ -107,10 +101,9 @@ def _test_squeeze_ethosu_BI_pipeline(
107101
test_data: Tuple[torch.Tensor, Optional[tuple[int]]],
108102
export_target: str,
109103
):
110-
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
111104
(
112105
ArmTester(module, example_inputs=test_data, compile_spec=compile_spec)
113-
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
106+
.quantize()
114107
.export()
115108
.check_count({export_target: 1})
116109
.to_edge()

backends/arm/test/ops/test_unsqueeze.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,9 @@
1313

1414
import torch
1515

16-
from executorch.backends.arm.quantizer.arm_quantizer import (
17-
ArmQuantizer,
18-
get_symmetric_quantization_config,
19-
)
2016
from executorch.backends.arm.test import common
2117
from executorch.backends.arm.test.tester.arm_tester import ArmTester
2218

23-
from executorch.backends.xnnpack.test.tester.tester import Quantize
2419
from executorch.exir.backend.compile_spec_schema import CompileSpec
2520
from parameterized import parameterized
2621

@@ -54,14 +49,13 @@ def _test_unsqueeze_tosa_MI_pipeline(
5449
def _test_unsqueeze_tosa_BI_pipeline(
5550
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, int]
5651
):
57-
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
5852
(
5953
ArmTester(
6054
module,
6155
example_inputs=test_data,
6256
compile_spec=common.get_tosa_compile_spec(),
6357
)
64-
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
58+
.quantize()
6559
.export()
6660
.check_count({"torch.ops.aten.unsqueeze.default": 1})
6761
.to_edge()
@@ -77,14 +71,13 @@ def _test_unsqueeze_ethosu_BI_pipeline(
7771
module: torch.nn.Module,
7872
test_data: Tuple[torch.Tensor, int],
7973
):
80-
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
8174
(
8275
ArmTester(
8376
module,
8477
example_inputs=test_data,
8578
compile_spec=compile_spec,
8679
)
87-
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
80+
.quantize()
8881
.export()
8982
.check_count({"torch.ops.aten.unsqueeze.default": 1})
9083
.to_edge()

0 commit comments

Comments
 (0)