Skip to content

Commit 6f11aa4

Browse files
committed
Remove meandim special case from NeedsDecompositionCheck
Signed-off-by: Adrian Lundell <[email protected]>
1 parent fa77f2f commit 6f11aa4

File tree

1 file changed

+17
-22
lines changed

1 file changed

+17
-22
lines changed

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -262,28 +262,23 @@ def is_node_supported(
262262

263263
if node.op != "call_function":
264264
return True
265-
if node.target == exir_ops.edge.aten.mean.dim:
266-
dim = node.args[1]
267-
needs_decomp = dim != [-1, -2]
268-
else:
269-
needs_decomp = node.target in [
270-
exir_ops.edge.aten.div.Tensor,
271-
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
272-
exir_ops.edge.aten.native_layer_norm.default,
273-
exir_ops.edge.aten.mean.dim,
274-
exir_ops.edge.aten._softmax.default,
275-
exir_ops.edge.aten._log_softmax.default,
276-
exir_ops.edge.aten.var.correction,
277-
exir_ops.edge.aten.var.dim,
278-
exir_ops.edge.aten.add.Scalar,
279-
exir_ops.edge.aten.sqrt.default,
280-
exir_ops.edge.aten.sub.Scalar,
281-
exir_ops.edge.aten.mul.Scalar,
282-
exir_ops.edge.aten.ne.Tensor,
283-
exir_ops.edge.aten.ne.Scalar,
284-
exir_ops.edge.aten.div.Scalar,
285-
exir_ops.edge.aten.leaky_relu.default,
286-
]
265+
needs_decomp = node.target in [
266+
exir_ops.edge.aten.div.Tensor,
267+
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
268+
exir_ops.edge.aten.native_layer_norm.default,
269+
exir_ops.edge.aten._softmax.default,
270+
exir_ops.edge.aten._log_softmax.default,
271+
exir_ops.edge.aten.var.correction,
272+
exir_ops.edge.aten.var.dim,
273+
exir_ops.edge.aten.add.Scalar,
274+
exir_ops.edge.aten.sqrt.default,
275+
exir_ops.edge.aten.sub.Scalar,
276+
exir_ops.edge.aten.mul.Scalar,
277+
exir_ops.edge.aten.ne.Tensor,
278+
exir_ops.edge.aten.ne.Scalar,
279+
exir_ops.edge.aten.div.Scalar,
280+
exir_ops.edge.aten.leaky_relu.default,
281+
]
287282
if needs_decomp:
288283
self.reporter.report_reject(node, "Needs to be decomposed.")
289284
return False

0 commit comments

Comments
 (0)