Skip to content

Commit 5b3d9fc

Browse files
Arm backend: Add partial support for aten.gather (pytorch#16561)
- Canonicalize edge.aten.gather to backend dialect tosa.GATHER - Register TOSA dialect op for GATHER - Add GatherVisitor lowering for tosa.GATHER - Add GatherSupported check for the restricted 2D gather pattern Change-Id: I0c31079a46bd3a2309ac337eff7824b7a8c0c661 cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Yufeng Shi <[email protected]> Co-authored-by: Sicheng Stephen Jia <[email protected]>
1 parent 4d5f330 commit 5b3d9fc

File tree

13 files changed

+596
-1
lines changed

13 files changed

+596
-1
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
1111
from .annotate_output_dim_order_pass import AnnotateOutputDimOrderPass # noqa
1212
from .broadcast_args_pass import BroadcastArgsPass # noqa
13+
from .canonicalize_gather_pass import CanonicalizeGatherPass # noqa
1314
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
1415
from .cast_to_int32_pass import CastToInt32Pass # noqa
1516
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
AnnotateDecomposedMatmulPass,
1616
AnnotateOutputDimOrderPass,
1717
BroadcastArgsPass,
18+
CanonicalizeGatherPass,
1819
CastInt64BuffersToInt32Pass,
1920
CastToInt32Pass,
2021
ComputeConstantOpsAOTPass,
@@ -228,6 +229,7 @@ def _tosa_pipeline(
228229
FuseQuantizedActivationPass(),
229230
RewriteBoolBitwiseNotToLogicalNotPass(),
230231
RewriteBoolToFp32CastViaInt8Pass(),
232+
CanonicalizeGatherPass(),
231233
ConvertToClampPass(),
232234
DecomposeTOSAUnsupportedClampPass(),
233235
DecomposeGroupNormPass(),
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import logging
8+
from typing import Set, Type
9+
10+
import torch
11+
from executorch.backends.arm._passes import ArmPass
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.pass_base import ExportPass
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
class CanonicalizeGatherPass(ArmPass):
19+
"""
20+
Canonicalize gather so it can be lowered to TOSA.GATHER via the backend dialect.
21+
22+
This pass is intended to run only for nodes already gated by GatherSupported.
23+
24+
Behavior:
25+
- Reshape x from [N,K] to [N,K,1] so values matches TOSA gather's [N,K,C].
26+
- Keep indices as [N,W]
27+
- Lower using tosa.GATHER.default, producing [N,W,1].
28+
- Reshape output to [N,W].
29+
- Only insert bool<->int8 casts when x is bool:
30+
* If x is bool: gather runs on int8 and output is cast back to bool.
31+
* If x is not bool: gather runs on original dtype and output keeps dtype.
32+
"""
33+
34+
_passes_required_after: Set[Type[ExportPass]] = set()
35+
36+
_TARGET_OPS = {exir_ops.edge.aten.gather.default}
37+
38+
def call_operator(self, op, args, kwargs, meta):
39+
if op not in self._TARGET_OPS:
40+
return super().call_operator(op, args, kwargs, meta)
41+
42+
# edge.aten.gather.default: (x, dim, index) with kw-only sparse_grad
43+
x, dim, index = args
44+
45+
# GatherSupported should have gated this already; treat violations as errors.
46+
x_shape = x.data.shape
47+
index_shape = index.data.shape
48+
if not (
49+
dim in (1, -1)
50+
and len(x_shape) == 2
51+
and len(index_shape) == 2
52+
and index_shape[0] == x_shape[0]
53+
):
54+
raise RuntimeError(
55+
f"[{op}] Unexpected gather pattern; expected "
56+
f"x:[N,K], index:[N,W], dim in {{1,-1}}, matching N. "
57+
f"Got dim={dim}, x.shape={x_shape}, index.shape={index_shape}."
58+
)
59+
60+
N, K = x_shape[0], x_shape[1]
61+
W = index_shape[1]
62+
63+
view_op = exir_ops.edge.aten.view_copy.default
64+
to_copy_op = exir_ops.edge.dim_order_ops._to_dim_order_copy.default
65+
66+
# Use backend dialect gather:
67+
# values: [N,K,C]
68+
# indices: [N,W]
69+
# output: [N,W,C]
70+
tosa_gather_op = exir_ops.backend.tosa.GATHER.default
71+
72+
needs_bool_cast = x.data.dtype == torch.bool
73+
74+
# bool -> int8 (only if needed)
75+
values_in = x
76+
if needs_bool_cast:
77+
values_in = super().call_operator(
78+
to_copy_op,
79+
(x,),
80+
{"dtype": torch.int8},
81+
meta,
82+
updated=True,
83+
)
84+
85+
# [N,K] -> [N,K,1]
86+
values_3d = super().call_operator(
87+
view_op,
88+
(values_in, [N, K, 1]),
89+
{},
90+
meta,
91+
updated=True,
92+
)
93+
94+
# indices stays [N,W]
95+
gathered_3d = super().call_operator(
96+
tosa_gather_op,
97+
(values_3d, index),
98+
{},
99+
meta,
100+
updated=True,
101+
)
102+
103+
# [N,W,1] -> [N,W]
104+
gathered_2d = super().call_operator(
105+
view_op,
106+
(gathered_3d, [N, W]),
107+
{},
108+
meta,
109+
updated=True,
110+
)
111+
112+
# int8 -> bool (only if needed)
113+
if needs_bool_cast:
114+
return super().call_operator(
115+
to_copy_op,
116+
(gathered_2d,),
117+
{"dtype": torch.bool},
118+
meta,
119+
updated=True,
120+
)
121+
122+
return gathered_2d

backends/arm/operator_support/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -10,6 +10,7 @@
1010
convolution_support,
1111
embedding_support,
1212
ethos_u55_support,
13+
gather_support,
1314
index_select_support,
1415
index_tensor_support,
1516
minmax_support,

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ class EthosU55NotSupported(OperatorSupportBase):
204204
exir_ops.edge.aten.ne.Tensor,
205205
exir_ops.edge.aten.ne.Scalar,
206206
exir_ops.edge.aten.flip.default, # REVERSE
207+
exir_ops.edge.aten.gather.default, # GATHER
207208
exir_ops.edge.aten.grid_sampler_2d, # GATHER
208209
exir_ops.edge.aten.index.Tensor, # GATHER
209210
exir_ops.edge.aten.index_select.default, # GATHER
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
"""
6+
Declare operator support for ``edge.aten.gather`` in TOSA.
7+
8+
This support check matches the subset accepted by CanonicalizeGatherPass:
9+
10+
- target: exir_ops.edge.aten.gather.default
11+
- args: exactly (x, dim, index) (i.e. len(node.args) == 3)
12+
- dim must be 1 or -1
13+
- x must be rank-2
14+
- index must be rank-2
15+
- index dtype must be int32
16+
- batch dim must match: x.shape[0] == index.shape[0]
17+
18+
Dtype gating is capability-based:
19+
20+
- int8/int16/int32 values require INT profile.
21+
- bool values require INT profile (handled via casts: bool -> int8 -> bool).
22+
- fp16/fp32 values are supported via FP profile directly, or via quantization
23+
when running under an INT profile.
24+
25+
Note:
26+
- CanonicalizeGatherPass reshapes values to [N, K, 1] and keeps indices as [N, W],
27+
then lowers via the TOSA gather dialect.
28+
"""
29+
30+
import torch
31+
import torch.fx as fx
32+
33+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
34+
register_tosa_support_check,
35+
SupportedTOSAOperatorCheck,
36+
)
37+
from executorch.backends.arm.tosa import TosaSpecification
38+
from executorch.exir.dialects._ops import ops as exir_ops
39+
40+
41+
@register_tosa_support_check
42+
class GatherSupported(SupportedTOSAOperatorCheck):
43+
"""Provide TOSA support check for ``edge.aten.gather``."""
44+
45+
targets = [exir_ops.edge.aten.gather.default]
46+
47+
tosa_specs = [
48+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
49+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
50+
]
51+
52+
def is_node_tosa_supported(
53+
self, node: fx.Node, tosa_spec: TosaSpecification
54+
) -> bool: # type: ignore[override, misc]
55+
if len(node.args) != 3:
56+
self.reporter.report_reject(
57+
node,
58+
f"{node.target}: expected 3 args (x, dim, index), got "
59+
f"{len(node.args)}.",
60+
)
61+
return False
62+
63+
x_arg, dim, index_arg = node.args[0], node.args[1], node.args[2]
64+
x_val = x_arg.meta["val"] # type: ignore[union-attr]
65+
index_val = index_arg.meta["val"] # type: ignore[union-attr]
66+
67+
x_shape = tuple(x_val.shape)
68+
index_shape = tuple(index_val.shape)
69+
70+
# ---- index dtype ----
71+
if index_val.dtype != torch.int32:
72+
self.reporter.report_reject(
73+
node,
74+
f"{node.target}: index dtype {index_val.dtype} not supported; "
75+
"expected int32.",
76+
)
77+
return False
78+
79+
# ---- dim + rank ----
80+
if not (
81+
(dim == 1 or dim == -1) and len(x_shape) == 2 and len(index_shape) == 2
82+
):
83+
self.reporter.report_reject(
84+
node,
85+
f"{node.target}: unsupported dim/rank; got {dim=}, "
86+
f"x_rank={len(x_shape)}, index_rank={len(index_shape)}; "
87+
"supported: dim in {1, -1} with rank-2 x and rank-2 index.",
88+
)
89+
return False
90+
91+
# ---- batch dim compatibility ----
92+
if x_shape[0] != index_shape[0]:
93+
self.reporter.report_reject(
94+
node,
95+
f"{node.target}: batch mismatch {x_shape[0]=} vs {index_shape[0]=}.",
96+
)
97+
return False
98+
99+
# ---- values dtype ----
100+
values_dtype = x_val.dtype
101+
# ints (and bool via casts) require INT profile
102+
if values_dtype in (torch.bool, torch.int8, torch.int16, torch.int32):
103+
if not tosa_spec.support_integer():
104+
self.reporter.report_reject(
105+
node,
106+
f"{node.target}: dtype {values_dtype} requires INT profile.",
107+
)
108+
return False
109+
# fp16/fp32: either FP profile, or INT profile (via quantization)
110+
elif values_dtype in (torch.float16, torch.float32):
111+
if not (tosa_spec.support_float() or tosa_spec.support_integer()):
112+
self.reporter.report_reject(
113+
node,
114+
f"{node.target}: dtype {values_dtype} requires FP profile or "
115+
"INT profile (with quantization).",
116+
)
117+
return False
118+
else:
119+
self.reporter.report_reject(
120+
node,
121+
f"{node.target}: unsupported values dtype {values_dtype}; "
122+
"expected bool/int8/int16/int32/float16/float32.",
123+
)
124+
return False
125+
126+
return True

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
op_tosa_conv2d,
5959
op_tosa_conv3d,
6060
op_tosa_depthwise_conv2d,
61+
op_tosa_gather,
6162
op_tosa_matmul,
6263
op_tosa_rescale,
6364
op_tosa_resize,
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, List
7+
8+
import tosa_serializer as ts
9+
10+
from executorch.backends.arm.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.arm.operators.operator_validation_utils import (
15+
validate_num_inputs,
16+
validate_same_dtype,
17+
validate_valid_dtype,
18+
)
19+
from executorch.backends.arm.tosa.mapping import TosaArg
20+
from torch.fx import Node
21+
22+
23+
@register_node_visitor
24+
class GatherVisitor(NodeVisitor):
25+
"""
26+
Lowers backend TOSA dialect `tosa.GATHER.default`.
27+
28+
Expected signature (per TOSA):
29+
values: [N, K, C] (rank 3)
30+
indices: [N, W] (rank 2, int32)
31+
output: [N, W, C] (rank 3)
32+
"""
33+
34+
target = "tosa.GATHER.default"
35+
tosa_specs = NodeVisitor.tosa_specs
36+
37+
def define_node(
38+
self,
39+
node: Node,
40+
tosa_graph: Any,
41+
inputs: List[TosaArg],
42+
output: TosaArg,
43+
) -> None:
44+
validate_num_inputs(self.target, inputs, 2)
45+
46+
values = inputs[0]
47+
indices = inputs[1]
48+
49+
validate_same_dtype(self.target, [values, output], ts)
50+
# Indices must be int32 for TOSA GATHER
51+
validate_valid_dtype(
52+
self.target,
53+
[indices],
54+
[ts.DType.INT32],
55+
output.tosa_spec,
56+
)
57+
validate_valid_dtype(
58+
self.target,
59+
[values, output],
60+
[
61+
ts.DType.INT8,
62+
ts.DType.INT16,
63+
ts.DType.INT32,
64+
ts.DType.FP16,
65+
ts.DType.FP32,
66+
],
67+
output.tosa_spec,
68+
)
69+
70+
attr = ts.TosaSerializerAttribute()
71+
attr.GatherAttribute()
72+
73+
self._serialize_operator(
74+
node,
75+
tosa_graph,
76+
ts.Op.GATHER,
77+
[values.name, indices.name],
78+
[output.name],
79+
attr,
80+
)

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ def _match_pattern(
428428
torch.ops.aten.clamp.default,
429429
torch.ops.aten.clamp.Tensor,
430430
torch.ops.aten.unflatten.int,
431+
torch.ops.aten.gather.default,
431432
torch.ops.aten.index_select.default,
432433
torch.ops.aten.index.Tensor,
433434
# Neg operator flips the range, but keps the magnitude the same.

0 commit comments

Comments
 (0)