 # pyre-unsafe

 import typing
+from typing import cast

 import torch
 import torch.fx as fx
+
 from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
 from executorch.backends.arm._passes.insert_table_ops import TableOps
 from executorch.backends.arm.operators.op_permute import transform_permutation_vector
 from executorch.backends.arm.tosa_utils import tosa_shape
 from executorch.exir.backend.utils import WhyNoPartitionReporter
-
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch.fx.passes.operator_support import OperatorSupportBase

@@ -212,19 +213,51 @@ def is_node_supported(
         Returns:
             False if the operator is not support and True if it is supported.
         """
-        if not node.target == exir_ops.edge.aten.view_copy.default:
+        # Select decomposes into squeeze, which in turn becomes a view. Therefore,
+        # perform the same check on select operators as view operators.
+        if node.target not in (
+            exir_ops.edge.aten.view_copy.default,
+            exir_ops.edge.aten.select.int,
+            exir_ops.edge.aten.select_copy.int,
+        ):
             return True

-        shape = list(get_first_fake_tensor(node).shape)
+        if node.target in (
+            exir_ops.edge.aten.select.int,
+            exir_ops.edge.aten.select_copy.int,
+        ):
+            input_node, dim, index = cast(tuple[fx.Node, int, int], node.args)
+
+            shape = input_node.meta["val"].shape
+            rank = len(shape)
+            if not -rank <= dim < rank:
+                raise IndexError(
+                    f"Dim {dim} is outside of the range for tensor '{node.target}' of "
+                    f"rank {rank}"
+                )
+            dim = dim % rank
+
+            size = shape[dim]
+            if not -size <= index < size:
+                raise IndexError(
+                    f"Index {index} is outside of the range for dim {dim} with size "
+                    f"{size} for tensor {node.target}"
+                )
+            index = index % size
+
+            # Shape after squeeze. This may get converted into a view which may become
+            # a transpose. This is why we're checking select.
+            squeezed_shape = shape[:dim] + shape[dim + 1 :]
+            shape = squeezed_shape
+        else:
+            shape = list(get_first_fake_tensor(node).shape)
+
         dtype = _try_determine_dtype(node)
-        permutation = list(typing.cast(list[int], node.args[1]))

         rank = len(shape)
         if rank > 4:
             if dtype == torch.int32:
-                self.reporter.report_reject(
-                    node, f"No support for {permutation=} in int32."
-                )
+                self.reporter.report_reject(node, "No support for rank > 4 in int32.")
                 return False

         if dtype in (torch.int8, torch.int16):
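As a side note, the shape bookkeeping added for select can be sketched in isolation. The snippet below mirrors the dim/index normalization and the squeeze-shape computation from the diff; the shape, dim, and index values are made up for illustration.

# Standalone sketch of the select handling above: normalize a possibly
# negative dim/index, then drop the selected dim the way squeeze would.
# The shape, dim and index values are hypothetical.
shape = (1, 4, 6, 3)
dim, index = -2, -1

rank = len(shape)
if not -rank <= dim < rank:
    raise IndexError(f"Dim {dim} is out of range for rank {rank}")
dim = dim % rank  # -2 -> 2

size = shape[dim]
if not -size <= index < size:
    raise IndexError(f"Index {index} is out of range for dim {dim} with size {size}")
index = index % size  # -1 -> 5

squeezed_shape = shape[:dim] + shape[dim + 1 :]
print(squeezed_shape)  # (1, 4, 3)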