2727 MemoryFormatOpsPassTestUtils ,
2828 MemoryFormatTestSet ,
2929 PropagateToCopyChannalsLastModule ,
30+ SimpleEmptyChannelLastModule ,
31+ SimpleEmptyContiguoustModule ,
3032 SimpleToCopyChannelsLastModule ,
3133 SimpleToCopyContiguousModule ,
3234)
@@ -45,6 +47,7 @@ def test_op_to_copy_replacement_2d(self) -> None:
4547 self ,
4648 MemoryFormatTestSet (
4749 module = SimpleToCopyContiguousModule ().eval (),
50+ op = torch .ops .aten ._to_copy .default ,
4851 sample_input = (torch .randn ([3 , 4 , 5 ], dtype = torch .float32 ),),
4952 target_memory_format = torch .contiguous_format ,
5053 _load_for_executorch_from_buffer = _load_for_executorch_from_buffer ,
@@ -56,17 +59,43 @@ def test_op_to_copy_replacement_4d(self) -> None:
5659 self ,
5760 MemoryFormatTestSet (
5861 module = SimpleToCopyContiguousModule ().eval (),
62+ op = torch .ops .aten ._to_copy .default ,
5963 sample_input = (torch .randn ([3 , 4 , 5 , 6 ], dtype = torch .float32 ),),
6064 target_memory_format = torch .contiguous_format ,
6165 _load_for_executorch_from_buffer = _load_for_executorch_from_buffer ,
6266 ),
6367 )
6468
69+ def test_op_empty_replacement_channels_last (self ) -> None :
70+ MemoryFormatOpsPassTestUtils .memory_format_test_runner (
71+ self ,
72+ MemoryFormatTestSet (
73+ module = SimpleEmptyChannelLastModule ().eval (),
74+ op = torch .ops .aten .empty .memory_format ,
75+ sample_input = (torch .randn ((1 , 10 , 24 , 24 ), dtype = torch .float32 ),),
76+ target_memory_format = torch .channels_last ,
77+ _load_for_executorch_from_buffer = _load_for_executorch_from_buffer ,
78+ ),
79+ )
80+
81+ def test_op_empty_replacement_contiguous (self ) -> None :
82+ MemoryFormatOpsPassTestUtils .memory_format_test_runner (
83+ self ,
84+ MemoryFormatTestSet (
85+ module = SimpleEmptyContiguoustModule ().eval (),
86+ op = torch .ops .aten .empty .memory_format ,
87+ sample_input = (torch .randn ((1 , 10 , 24 , 24 ), dtype = torch .float32 ),),
88+ target_memory_format = torch .contiguous_format ,
89+ _load_for_executorch_from_buffer = _load_for_executorch_from_buffer ,
90+ ),
91+ )
92+
6593 def test_op_dim_order_update (self ) -> None :
6694 MemoryFormatOpsPassTestUtils .memory_format_test_runner (
6795 self ,
6896 MemoryFormatTestSet (
6997 module = SimpleToCopyChannelsLastModule ().eval (),
98+ op = torch .ops .aten ._to_copy .default ,
7099 sample_input = (
71100 torch .rand_like (
72101 torch .zeros ([2 , 2 , 2 , 2 ]),
@@ -84,6 +113,7 @@ def test_op_dim_order_propagation(self) -> None:
84113 self ,
85114 MemoryFormatTestSet (
86115 module = PropagateToCopyChannalsLastModule ().eval (),
116+ op = torch .ops .aten ._to_copy .default ,
87117 sample_input = (
88118 torch .rand_like (
89119 torch .zeros ([2 , 2 , 2 , 2 ]),
@@ -273,6 +303,7 @@ def test_resnet18(self) -> None:
273303 self ,
274304 MemoryFormatTestSet (
275305 module = model .eval (),
306+ op = torch .ops .aten ._to_copy .default ,
276307 sample_input = (torch .randn (1 , 3 , 224 , 224 ),),
277308 target_memory_format = torch .contiguous_format ,
278309 op_level_check = False ,
@@ -288,6 +319,7 @@ def test_resnet18_xnnpack(self) -> None:
288319 self ,
289320 MemoryFormatTestSet (
290321 module = model .eval (),
322+ op = torch .ops .aten ._to_copy .default ,
291323 sample_input = (torch .randn (1 , 3 , 224 , 224 ),),
292324 target_memory_format = torch .contiguous_format ,
293325 op_level_check = False ,
@@ -304,6 +336,7 @@ def test_mobilenet_v3(self) -> None:
304336 self ,
305337 MemoryFormatTestSet (
306338 module = model .eval (),
339+ op = torch .ops .aten ._to_copy .default ,
307340 sample_input = (torch .randn (1 , 3 , 224 , 224 ),),
308341 target_memory_format = torch .contiguous_format ,
309342 op_level_check = False ,
@@ -319,6 +352,7 @@ def test_mobilenet_v3_xnnpack(self) -> None:
319352 self ,
320353 MemoryFormatTestSet (
321354 module = model .eval (),
355+ op = torch .ops .aten ._to_copy .default ,
322356 sample_input = (torch .randn (1 , 3 , 224 , 224 ),),
323357 target_memory_format = torch .contiguous_format ,
324358 op_level_check = False ,
0 commit comments