Skip to content

Commit 8c116d0

Browse files
Sebastian-LarssonStrycekSimon
authored andcommitted
Arm backend: Replace asserts/raises with reporter rejects (pytorch#14371)
- embedding_support: replace input-count assert with reporter.report_reject + return False - index_tensor_support: add explicit rejects for None in indices, rank >= 4 indexing tensors, and int32 overflow of value tensor; previously returned False without explanation - minmax_support: add reject when min/max.dim’s argmax output is used - ethos_u55_support: replace IndexError raises in view/select checks (invalid dim/index) with reporter.report_reject + return False - Improves partition diagnostics and avoids hard crashes Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 9805fde commit 8c116d0

File tree

4 files changed

+40
-11
lines changed

4 files changed

+40
-11
lines changed

backends/arm/operator_support/embedding_support.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,16 @@ class EmbeddingSupported(SupportedTOSAOperatorCheck):
2727
def is_node_tosa_supported(
2828
self, node: fx.Node, tosa_spec: TosaSpecification
2929
) -> bool: # type: ignore[override, misc]
30-
# Note aten.embedding.default requires int64 indices and TOSA does not support it.
31-
# Int32 indices here for aten.embedding.default is ok since it will be decomposed into ops that can handle it.
32-
assert (
33-
len(node.all_input_nodes) == 2
34-
), "Number of inputs to aten.embedding is not 2"
30+
# Note aten.embedding.default requires int64 indices and TOSA does not
31+
# support it. Int32 indices here for aten.embedding.default is ok since
32+
# it will be decomposed into ops that can handle it.
33+
34+
if len(node.all_input_nodes) != 2:
35+
self.reporter.report_reject(
36+
node,
37+
(f"Expected exactly two input nodes, got {len(node.all_input_nodes)}"),
38+
)
39+
return False
3540
indices_val = node.all_input_nodes[1].meta["val"]
3641
indices_dtype = indices_val.dtype
3742

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -236,18 +236,20 @@ def is_node_supported(
236236
shape = input_node.meta["val"].shape
237237
rank = len(shape)
238238
if not -rank <= dim < rank:
239-
raise IndexError(
240-
f"Dim {dim} is outside of the range for tensor '{node.target}' of "
241-
f"rank {rank}"
239+
self.reporter.report_reject(
240+
node,
241+
(f"Dimension {dim} out of range for rank {rank}."),
242242
)
243+
return False
243244
dim = dim % rank
244245

245246
size = shape[dim]
246247
if not -size <= index < size:
247-
raise IndexError(
248-
f"Index {index} is outside of the range for dim {dim} with size "
249-
f"{size} for tensor {node.target}"
248+
self.reporter.report_reject(
249+
node,
250+
(f"Index {index} out of range for dim {dim} with size {size}."),
250251
)
252+
return False
251253
index = index % size
252254

253255
# Shape after squeeze. This may get converted into a view which may become

backends/arm/operator_support/index_tensor_support.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,16 +111,31 @@ def is_node_tosa_supported(
111111
for index in indices: # type: ignore[union-attr]
112112
# Usage 2 guard
113113
if index is None:
114+
self.reporter.report_reject(
115+
node,
116+
(
117+
"None (from slice/unsqueeze/ellipsis) before an indexing tensor"
118+
" is not supported."
119+
),
120+
)
114121
return False
115122

116123
# Usage 1 guard
117124
fake_tensor = get_first_fake_tensor(index) # type: ignore[arg-type]
118125
if len(fake_tensor.size()) > 3:
126+
self.reporter.report_reject(
127+
node,
128+
("Indexing tensors of rank >= 4 is not supported."),
129+
)
119130
return False
120131

121132
# Usage 3 guard
122133
total_vals = math.prod(get_first_fake_tensor(node.args[0]).shape) # type: ignore[arg-type]
123134
if total_vals > torch.iinfo(torch.int32).max:
135+
self.reporter.report_reject(
136+
node,
137+
("Value size exceeds int32 range; would overflow flattened indexing."),
138+
)
124139
return False
125140

126141
return True

backends/arm/operator_support/minmax_support.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
3232
)
3333

3434
if not (no_argmax or no_argmax_users):
35+
self.reporter.report_reject(
36+
node,
37+
(
38+
"Using the indices output is not supported; only usage of the "
39+
"values output is supported."
40+
),
41+
)
3542
return False
3643

3744
return True

0 commit comments

Comments
 (0)