Skip to content

Commit b8485bc

Browse files
committed
Refactor clone identity check into _is_non_identity_clone
1 parent e14a700 commit b8485bc

File tree

1 file changed

+24
-29
lines changed

1 file changed

+24
-29
lines changed

backends/transforms/remove_clone_ops.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,7 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
3535
if n.target not in self.clone_ops:
3636
continue
3737

38-
# Skip removal of clone ops that modify layout/dim order.
39-
if self.aten_clone_is_non_identity(
40-
n
41-
) or self._clone_dim_order_is_non_identity(n):
38+
if self._is_non_identity_clone(n):
4239
continue
4340

4441
to_be_removed = n
@@ -56,28 +53,26 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
5653
dead_code_elimination_pass(graph_module)
5754
return PassResult(graph_module, True)
5855

59-
def aten_clone_is_non_identity(self, node: torch.fx.Node) -> bool:
60-
"""Return True if aten.clone has modified memory format."""
61-
if node.target != exir_ops.edge.aten.clone.default:
62-
return False
63-
64-
memory_format = node.kwargs.get("memory_format")
65-
if memory_format in (None, torch.preserve_format):
66-
return False
67-
68-
input_meta = node.args[0].meta
69-
return "val" in input_meta and not input_meta["val"].is_contiguous(
70-
memory_format=memory_format
71-
)
72-
73-
def _clone_dim_order_is_non_identity(self, node: torch.fx.Node) -> bool:
74-
"""Return True if _clone_dim_order has modified dim order."""
75-
if node.target != exir_ops.edge.dim_order_ops._clone_dim_order.default:
76-
return False
77-
78-
input_meta = node.args[0].meta
79-
return (
80-
"val" in node.meta
81-
and "val" in input_meta
82-
and node.meta["val"].dim_order() != input_meta["val"].dim_order()
83-
)
56+
def _is_non_identity_clone(self, node: torch.fx.Node) -> bool:
57+
"""Return True if clone has modified memory layout or dim order."""
58+
59+
# aten.clone: check for memory_format changes
60+
if node.target == exir_ops.edge.aten.clone.default:
61+
memory_format = node.kwargs.get("memory_format")
62+
if memory_format in (None, torch.preserve_format):
63+
return False
64+
input_meta = node.args[0].meta
65+
return "val" in input_meta and not input_meta["val"].is_contiguous(
66+
memory_format=memory_format
67+
)
68+
69+
# _clone_dim_order: check for dim_order changes
70+
if node.target == exir_ops.edge.dim_order_ops._clone_dim_order.default:
71+
input_meta = node.args[0].meta
72+
return (
73+
"val" in node.meta
74+
and "val" in input_meta
75+
and node.meta["val"].dim_order() != input_meta["val"].dim_order()
76+
)
77+
78+
return False

0 commit comments

Comments
 (0)