@@ -35,10 +35,7 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
3535 if n .target not in self .clone_ops :
3636 continue
3737
38- # Skip removal of clone ops that modify layout/dim order.
39- if self .aten_clone_is_non_identity (
40- n
41- ) or self ._clone_dim_order_is_non_identity (n ):
38+ if self ._is_non_identity_clone (n ):
4239 continue
4340
4441 to_be_removed = n
@@ -56,28 +53,26 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
5653 dead_code_elimination_pass (graph_module )
5754 return PassResult (graph_module , True )
5855
59- def aten_clone_is_non_identity (self , node : torch .fx .Node ) -> bool :
60- """Return True if aten.clone has modified memory format."""
61- if node .target != exir_ops .edge .aten .clone .default :
62- return False
63-
64- memory_format = node .kwargs .get ("memory_format" )
65- if memory_format in (None , torch .preserve_format ):
66- return False
67-
68- input_meta = node .args [0 ].meta
69- return "val" in input_meta and not input_meta ["val" ].is_contiguous (
70- memory_format = memory_format
71- )
72-
73- def _clone_dim_order_is_non_identity (self , node : torch .fx .Node ) -> bool :
74- """Return True if _clone_dim_order has modified dim order."""
75- if node .target != exir_ops .edge .dim_order_ops ._clone_dim_order .default :
76- return False
77-
78- input_meta = node .args [0 ].meta
79- return (
80- "val" in node .meta
81- and "val" in input_meta
82- and node .meta ["val" ].dim_order () != input_meta ["val" ].dim_order ()
83- )
56+ def _is_non_identity_clone (self , node : torch .fx .Node ) -> bool :
57+ """Return True if clone has modified memory layout or dim order."""
58+
59+ # aten.clone: check for memory_format changes
60+ if node .target == exir_ops .edge .aten .clone .default :
61+ memory_format = node .kwargs .get ("memory_format" )
62+ if memory_format in (None , torch .preserve_format ):
63+ return False
64+ input_meta = node .args [0 ].meta
65+ return "val" in input_meta and not input_meta ["val" ].is_contiguous (
66+ memory_format = memory_format
67+ )
68+
69+ # _clone_dim_order: check for dim_order changes
70+ if node .target == exir_ops .edge .dim_order_ops ._clone_dim_order .default :
71+ input_meta = node .args [0 ].meta
72+ return (
73+ "val" in node .meta
74+ and "val" in input_meta
75+ and node .meta ["val" ].dim_order () != input_meta ["val" ].dim_order ()
76+ )
77+
78+ return False
0 commit comments