Skip to content

Commit f9f9515

Browse files
committed
Remove input dtype gating and add memory_format check
1 parent 74e2cce commit f9f9515

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

backends/arm/operator_support/clone_dim_order_support.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,11 @@ def is_node_tosa_supported(
3535
) -> bool:
3636
assert node.target in self.targets
3737

38-
supported_dtypes = {torch.bool, torch.int8, torch.int16, torch.int32}
39-
if tosa_spec.support_float():
40-
supported_dtypes |= {torch.bfloat16, torch.float16, torch.float32}
41-
4238
# Check input type
4339
assert len(node.all_input_nodes) == 1
4440
input_val = node.all_input_nodes[0].meta["val"]
4541
assert isinstance(input_val, torch._subclasses.FakeTensor)
4642
input_dtype = input_val.dtype
47-
if input_dtype not in supported_dtypes:
48-
self.reporter.report_reject(
49-
node,
50-
f"Input dtype {input_val.dtype} is not supported in {node.target}.",
51-
)
52-
return False
5343

5444
# Check output type
5545
output_val = node.meta["val"]
@@ -61,6 +51,16 @@ def is_node_tosa_supported(
6151
)
6252
return False
6353

54+
# Check memory format
55+
if "memory_format" in node.kwargs:
56+
if node.kwargs["memory_format"] in (torch.preserve_format,):
57+
self.reporter.report_reject(
58+
node,
59+
f"Argument 'memory_format' is not supported for "
60+
f"{node.target} right now.",
61+
)
62+
return False
63+
6464
# Check dim_order
6565
if "dim_order" in node.kwargs:
6666
dim_order = node.kwargs["dim_order"]

0 commit comments

Comments
 (0)