Skip to content

Commit 7060fc5

Browse files
Arm backend: Replace asserts with report_reject in operator_support (#13985)
Remove redundant check `node.target in self.targets`, as well as replacing asserts with proper report_reject. This way the graph won't be stopped from lowering, but the operators will instead end up on CPU. Signed-off-by: Sebastian Larsson <[email protected]>
1 parent b02db12 commit 7060fc5

File tree

1 file changed

+43
-14
lines changed

1 file changed

+43
-14
lines changed

backends/arm/operator_support/to_copy_support.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,6 @@ def _merge_supported_types(
7676
def is_node_tosa_supported(
7777
self, node: fx.Node, tosa_spec: TosaSpecification
7878
) -> bool:
79-
assert node.target in self.targets
80-
8179
supported_dtypes = (
8280
self.ALL_SUPPORTED_TYPES
8381
if tosa_spec.support_float()
@@ -90,10 +88,27 @@ def is_node_tosa_supported(
9088
if v in supported_dtypes
9189
)
9290

93-
# Check input type
94-
assert len(node.all_input_nodes) == 1
91+
if len(node.all_input_nodes) != 1:
92+
self.reporter.report_reject(
93+
node,
94+
(
95+
"Expected exactly one input node, "
96+
f"got {len(node.all_input_nodes)} for {node.target}."
97+
),
98+
)
99+
return False
95100
input_val = node.all_input_nodes[0].meta["val"]
96-
assert isinstance(input_val, torch._subclasses.FakeTensor)
101+
if not isinstance(input_val, torch._subclasses.FakeTensor):
102+
self.reporter.report_reject(
103+
node,
104+
(
105+
"Invalid or missing meta: expected FakeTensor input, got "
106+
f"{type(input_val).__name__} for {node.target}."
107+
),
108+
)
109+
return False
110+
111+
# Check input type
97112
input_dtype = input_val.dtype
98113
if input_dtype not in supported_dtypes:
99114
self.reporter.report_reject(
@@ -104,14 +119,24 @@ def is_node_tosa_supported(
104119

105120
# Check output type
106121
output_val = node.meta["val"]
107-
assert isinstance(output_val, torch._subclasses.FakeTensor)
122+
if not isinstance(output_val, torch._subclasses.FakeTensor):
123+
self.reporter.report_reject(
124+
node,
125+
(
126+
"Invalid or missing meta: expected FakeTensor output, got "
127+
f"{type(output_val).__name__} for {node.target}."
128+
),
129+
)
130+
return False
108131
if output_val.dtype not in supported_dtypes[input_dtype]:
109132
self.reporter.report_reject(
110133
node,
111-
f"Output dtype {output_val.dtype} is not supported in "
112-
f"{node.target} for input dtype {input_dtype}. "
113-
f"Supported output types: "
114-
f"{''.join(str(t) for t in supported_dtypes[input_dtype])}",
134+
(
135+
f"Output dtype {output_val.dtype} is not supported in "
136+
f"{node.target} for input dtype {input_dtype}. "
137+
f"Supported output types: "
138+
f"{', '.join(str(t) for t in supported_dtypes[input_dtype])}"
139+
),
115140
)
116141
return False
117142

@@ -120,8 +145,10 @@ def is_node_tosa_supported(
120145
if node.kwargs["memory_format"] in (torch.preserve_format,):
121146
self.reporter.report_reject(
122147
node,
123-
f"Argument 'memory_format' is not supported for "
124-
f"{node.target} right now.",
148+
(
149+
"Argument 'memory_format' is not supported for "
150+
f"{node.target} right now."
151+
),
125152
)
126153
return False
127154

@@ -132,8 +159,10 @@ def is_node_tosa_supported(
132159
if dim_order != list(range(len(dim_order))): # type: ignore[arg-type]
133160
self.reporter.report_reject(
134161
node,
135-
f"Argument {dim_order=} is not supported for "
136-
f"{node.target} right now.",
162+
(
163+
f"Argument {dim_order=} is not supported for "
164+
f"{node.target} right now."
165+
),
137166
)
138167
return False
139168

0 commit comments

Comments
 (0)