Skip to content

Commit b76cf7b

Browse files
committed
Refactor clone_dim_order_support to existing clone_support file
1 parent e8ceb5d commit b76cf7b

File tree

3 files changed

+54
-78
lines changed

3 files changed

+54
-78
lines changed

backends/arm/operator_support/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# pyre-unsafe
77

88
from . import ( # noqa
9-
clone_dim_order_support,
109
clone_support,
1110
convolution_support,
1211
embedding_support,

backends/arm/operator_support/clone_dim_order_support.py

Lines changed: 0 additions & 76 deletions
This file was deleted.

backends/arm/operator_support/clone_support.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import logging
77

8+
import torch
89
import torch.fx as fx
910
from executorch.backends.arm.operator_support.tosa_supported_operators import (
1011
register_tosa_support_check,
@@ -18,7 +19,7 @@
1819

1920
@register_tosa_support_check
2021
class CloneSupported(SupportedTOSAOperatorCheck):
21-
targets = [exir_ops.edge.aten.clone.default]
22+
targets = [exir_ops.edge.dim_order_ops._clone_dim_order.default]
2223

2324
tosa_specs = [
2425
TosaSpecification.create_from_string("TOSA-1.0+INT"),
@@ -28,10 +29,62 @@ class CloneSupported(SupportedTOSAOperatorCheck):
2829
def is_node_tosa_supported(
    self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool:
    """Return True if this ``_clone_dim_order`` node is supported for TOSA lowering.

    Rejects (with a report via ``self.reporter``) any node that:
      * is not one of ``self.targets``,
      * has a non-tensor first argument or more/fewer than one input node,
      * has an input or output ``meta["val"]`` that is not a FakeTensor,
      * changes dtype between input and output,
      * passes an unsupported ``memory_format`` kwarg, or
      * requests a non-trivial (permuting) ``dim_order``.

    Args:
        node: The FX node being checked for partitioning.
        tosa_spec: Target TOSA specification (not consulted here; the dtype/
            layout checks below are spec-independent).

    Returns:
        bool: True when the clone is a supported no-op-style copy, else False.
    """
    # Guard against being called on a node this check was not registered for.
    if node.target not in self.targets:
        self.reporter.report_reject(node, f"Target {node.target} is not supported.")
        return False

    input_node = node.args[0]
    if not isinstance(input_node, fx.Node):
        self.reporter.report_reject(node, "Non tensor clones are not supported")
        return False

    # Check input node: a clone must consume exactly one tensor producer.
    if len(node.all_input_nodes) != 1:
        self.reporter.report_reject(
            node, f"Expected 1 input node, got {len(node.all_input_nodes)}"
        )
        return False

    input_val = node.all_input_nodes[0].meta["val"]
    if not isinstance(input_val, torch._subclasses.FakeTensor):
        self.reporter.report_reject(node, "Expected input to be a FakeTensor.")
        return False

    input_dtype = input_val.dtype

    # Check output node: output must also carry FakeTensor metadata and
    # keep the input dtype (a dtype-changing clone is a cast, not a copy).
    output_val = node.meta["val"]
    if not isinstance(output_val, torch._subclasses.FakeTensor):
        self.reporter.report_reject(node, "Expected output to be a FakeTensor.")
        return False

    if output_val.dtype != input_dtype:
        self.reporter.report_reject(
            node,
            f"Input dtype {input_val.dtype} does not match {output_val.dtype}.",
        )
        return False

    # Check memory format: torch.preserve_format is explicitly rejected here.
    # NOTE(review): only preserve_format is blocked; other memory formats fall
    # through as supported — confirm this matches the lowering's expectations.
    if "memory_format" in node.kwargs:
        if node.kwargs["memory_format"] in (torch.preserve_format,):
            self.reporter.report_reject(
                node,
                f"Argument 'memory_format' is not supported for "
                f"{node.target} right now.",
            )
            return False

    # Check dim_order: only the identity order [0, 1, ..., rank-1] is
    # supported, i.e. the clone must not permute dimensions.
    if "dim_order" in node.kwargs:
        dim_order = node.kwargs["dim_order"]
        # pyre-ignore[6]
        if dim_order != list(range(len(dim_order))):  # type: ignore[arg-type]
            self.reporter.report_reject(
                node,
                f"Argument {dim_order=} is not supported for "
                f"{node.target} right now.",
            )
            return False

    return True

0 commit comments

Comments
 (0)