2727 AmbiguousDimOrderError ,
2828 MemoryFormatOpsPassTestUtils ,
2929 MemoryFormatTestSet ,
30+ PropagateToCloneChannelsLastModule ,
3031 PropagateToCopyChannalsLastModule ,
3132 SimpleCloneChannelsLastModule ,
33+ SimpleCloneContiguousModule ,
3234 SimpleEmptyChannelLastModule ,
3335 SimpleEmptyContiguoustModule ,
3436 SimpleToCopyChannelsLastModule ,
@@ -92,6 +94,36 @@ def test_op_empty_replacement_contiguous(self) -> None:
9294 ),
9395 )
9496
97+ def test_op_clone_replacement_contiguous (self ) -> None :
98+ model = SimpleCloneContiguousModule ()
99+ MemoryFormatOpsPassTestUtils .memory_format_test_runner (
100+ self ,
101+ MemoryFormatTestSet (
102+ module = model .eval (),
103+ op = torch .ops .aten .clone .default ,
104+ sample_input = (
105+ torch .randn ((3 , 4 , 5 , 6 )).to (memory_format = torch .channels_last ),
106+ ),
107+ target_memory_format = torch .contiguous_format ,
108+ _load_for_executorch_from_buffer = _load_for_executorch_from_buffer ,
109+ ),
110+ )
111+
112+ def test_op_clone_replacement_channels_last (self ) -> None :
113+ model = SimpleCloneChannelsLastModule ()
114+ MemoryFormatOpsPassTestUtils .memory_format_test_runner (
115+ self ,
116+ MemoryFormatTestSet (
117+ module = model .eval (),
118+ op = torch .ops .aten .clone .default ,
119+ sample_input = (
120+ torch .randn ((3 , 4 , 5 , 6 )).to (memory_format = torch .contiguous_format ),
121+ ),
122+ target_memory_format = torch .channels_last ,
123+ _load_for_executorch_from_buffer = _load_for_executorch_from_buffer ,
124+ ),
125+ )
126+
95127 def test_op_dim_order_update (self ) -> None :
96128 MemoryFormatOpsPassTestUtils .memory_format_test_runner (
97129 self ,
@@ -129,6 +161,25 @@ def test_op_dim_order_propagation(self) -> None:
129161 check_unambiguous_dim_order = True ,
130162 )
131163
164+ def test_op_clone_dim_order_propagation (self ) -> None :
165+ MemoryFormatOpsPassTestUtils .memory_format_test_runner (
166+ self ,
167+ MemoryFormatTestSet (
168+ module = PropagateToCloneChannelsLastModule ().eval (),
169+ op = torch .ops .aten .clone .default ,
170+ sample_input = (
171+ torch .rand_like (
172+ torch .zeros ([2 , 2 , 2 , 2 ]),
173+ dtype = torch .float32 ,
174+ memory_format = torch .contiguous_format ,
175+ ),
176+ ),
177+ target_memory_format = torch .channels_last ,
178+ _load_for_executorch_from_buffer = _load_for_executorch_from_buffer ,
179+ ),
180+ check_unambiguous_dim_order = True ,
181+ )
182+
132183 def test_op_dim_order_propagation_ambiguous (self ) -> None :
133184 try :
134185 MemoryFormatOpsPassTestUtils .memory_format_test_runner (
@@ -154,6 +205,29 @@ def test_op_dim_order_propagation_ambiguous(self) -> None:
154205 except AmbiguousDimOrderError :
155206 pass # Expected error
156207
208+ def test_op_clone_dim_order_graph_replacement (self ):
209+ model = SimpleCloneChannelsLastModule ()
210+ x = torch .randn (3 , 4 , 5 , 6 ).to (memory_format = torch .contiguous_format )
211+ _clone_dim_order_op_str = (
212+ "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
213+ )
214+
215+ exported = export (model .eval (), (x ,), strict = True )
216+ epm = to_edge (exported , compile_config = EdgeCompileConfig (_skip_dim_order = False ))
217+
218+ # Verify one _clone_dim_order op exists and aten.clone.default nodes have been removed.
219+ (
220+ FileCheck ()
221+ .check_not (
222+ "aten.clone.default"
223+ ) # Check before first _clone_dim_order_op_str match.
224+ .check_count (_clone_dim_order_op_str , 1 , exactly = True )
225+ .check_not (
226+ "aten.clone.default"
227+ ) # Check after _clone_dim_order_op_str match.
228+ .run (epm .exported_program ().graph_module .code )
229+ )
230+
157231 # Only test dim order replacement result in lean mode test.
158232 # This test is irrelevant with operator mode.
159233 def test_dim_order_replacement (self ) -> None :
@@ -390,17 +464,3 @@ def test_mobilenet_v3_xnnpack(self) -> None:
390464 rtol = 1e-3 ,
391465 ),
392466 )
393-
394- def test_op_clone_dim_order_registration (self ):
395- model = SimpleCloneChannelsLastModule ()
396- x = torch .randn (3 , 4 , 5 , 6 ).to (memory_format = torch .contiguous_format )
397- clone_dim_order_op_str = (
398- "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
399- )
400-
401- exported = export (model .eval (), (x ,), strict = True )
402- epm = to_edge (exported , compile_config = EdgeCompileConfig (_skip_dim_order = False ))
403-
404- FileCheck ().check_count (clone_dim_order_op_str , 1 , exactly = True ).run (
405- epm .exported_program ().graph_module .code
406- )
0 commit comments