2727    MemoryFormatOpsPassTestUtils ,
2828    MemoryFormatTestSet ,
2929    PropagateToCopyChannalsLastModule ,
30+     SimpleEmptyChannelLastModule ,
31+     SimpleEmptyContiguoustModule ,
3032    SimpleToCopyChannelsLastModule ,
3133    SimpleToCopyContiguousModule ,
3234)
@@ -45,6 +47,7 @@ def test_op_to_copy_replacement_2d(self) -> None:
4547            self ,
4648            MemoryFormatTestSet (
4749                module = SimpleToCopyContiguousModule ().eval (),
50+                 op = torch .ops .aten ._to_copy .default ,
4851                sample_input = (torch .randn ([3 , 4 , 5 ], dtype = torch .float32 ),),
4952                target_memory_format = torch .contiguous_format ,
5053                _load_for_executorch_from_buffer = _load_for_executorch_from_buffer ,
@@ -56,17 +59,43 @@ def test_op_to_copy_replacement_4d(self) -> None:
5659            self ,
5760            MemoryFormatTestSet (
5861                module = SimpleToCopyContiguousModule ().eval (),
62+                 op = torch .ops .aten ._to_copy .default ,
5963                sample_input = (torch .randn ([3 , 4 , 5 , 6 ], dtype = torch .float32 ),),
6064                target_memory_format = torch .contiguous_format ,
6165                _load_for_executorch_from_buffer = _load_for_executorch_from_buffer ,
6266            ),
6367        )
6468
69+     def  test_op_empty_replacement_channels_last (self ) ->  None :
70+         MemoryFormatOpsPassTestUtils .memory_format_test_runner (
71+             self ,
72+             MemoryFormatTestSet (
73+                 module = SimpleEmptyChannelLastModule ().eval (),
74+                 op = torch .ops .aten .empty .memory_format ,
75+                 sample_input = (torch .randn ((1 , 10 , 24 , 24 ), dtype = torch .float32 ),),
76+                 target_memory_format = torch .channels_last ,
77+                 _load_for_executorch_from_buffer = _load_for_executorch_from_buffer ,
78+             ),
79+         )
80+ 
81+     def  test_op_empty_replacement_contiguous (self ) ->  None :
82+         MemoryFormatOpsPassTestUtils .memory_format_test_runner (
83+             self ,
84+             MemoryFormatTestSet (
85+                 module = SimpleEmptyContiguoustModule ().eval (),
86+                 op = torch .ops .aten .empty .memory_format ,
87+                 sample_input = (torch .randn ((1 , 10 , 24 , 24 ), dtype = torch .float32 ),),
88+                 target_memory_format = torch .contiguous_format ,
89+                 _load_for_executorch_from_buffer = _load_for_executorch_from_buffer ,
90+             ),
91+         )
92+ 
6593    def  test_op_dim_order_update (self ) ->  None :
6694        MemoryFormatOpsPassTestUtils .memory_format_test_runner (
6795            self ,
6896            MemoryFormatTestSet (
6997                module = SimpleToCopyChannelsLastModule ().eval (),
98+                 op = torch .ops .aten ._to_copy .default ,
7099                sample_input = (
71100                    torch .rand_like (
72101                        torch .zeros ([2 , 2 , 2 , 2 ]),
@@ -84,6 +113,7 @@ def test_op_dim_order_propagation(self) -> None:
84113            self ,
85114            MemoryFormatTestSet (
86115                module = PropagateToCopyChannalsLastModule ().eval (),
116+                 op = torch .ops .aten ._to_copy .default ,
87117                sample_input = (
88118                    torch .rand_like (
89119                        torch .zeros ([2 , 2 , 2 , 2 ]),
@@ -273,6 +303,7 @@ def test_resnet18(self) -> None:
273303            self ,
274304            MemoryFormatTestSet (
275305                module = model .eval (),
306+                 op = torch .ops .aten ._to_copy .default ,
276307                sample_input = (torch .randn (1 , 3 , 224 , 224 ),),
277308                target_memory_format = torch .contiguous_format ,
278309                op_level_check = False ,
@@ -288,6 +319,7 @@ def test_resnet18_xnnpack(self) -> None:
288319            self ,
289320            MemoryFormatTestSet (
290321                module = model .eval (),
322+                 op = torch .ops .aten ._to_copy .default ,
291323                sample_input = (torch .randn (1 , 3 , 224 , 224 ),),
292324                target_memory_format = torch .contiguous_format ,
293325                op_level_check = False ,
@@ -304,6 +336,7 @@ def test_mobilenet_v3(self) -> None:
304336            self ,
305337            MemoryFormatTestSet (
306338                module = model .eval (),
339+                 op = torch .ops .aten ._to_copy .default ,
307340                sample_input = (torch .randn (1 , 3 , 224 , 224 ),),
308341                target_memory_format = torch .contiguous_format ,
309342                op_level_check = False ,
@@ -319,6 +352,7 @@ def test_mobilenet_v3_xnnpack(self) -> None:
319352            self ,
320353            MemoryFormatTestSet (
321354                module = model .eval (),
355+                 op = torch .ops .aten ._to_copy .default ,
322356                sample_input = (torch .randn (1 , 3 , 224 , 224 ),),
323357                target_memory_format = torch .contiguous_format ,
324358                op_level_check = False ,
0 commit comments