@@ -28,8 +28,6 @@
     MemoryFormatOpsPassTestUtils,
     MemoryFormatTestSet,
     PropagateToCopyChannalsLastModule,
-    SimpleCloneChannelsLastModule,
-    SimpleCloneContiguousModule,
     SimpleEmptyChannelLastModule,
     SimpleEmptyContiguoustModule,
     SimpleToCopyChannelsLastModule,
@@ -93,36 +91,40 @@ def test_op_empty_replacement_contiguous(self) -> None:
             ),
         )
 
-    def test_op_clone_replacement_contiguous(self) -> None:
-        model = SimpleCloneContiguousModule()
-        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
-            self,
-            MemoryFormatTestSet(
-                module=model.eval(),
-                op=torch.ops.aten.clone.default,
-                sample_input=(
-                    torch.randn((3, 4, 5, 6)).to(memory_format=torch.channels_last),
-                ),
-                target_memory_format=torch.contiguous_format,
-                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
-            ),
+    def test_op_clone_dim_order_preserves_channels_last(self) -> None:
+        x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
+        y = torch.ops.dim_order_ops._clone_dim_order.default(x)
+
+        assert y.is_contiguous(
+            memory_format=torch.channels_last
+        ), "_clone_dim_order output is not in channels_last memory format"
+        assert torch.allclose(x, y)
+
+    def test_op_clone_dim_order_to_contiguous(self) -> None:
+        x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
+        contiguous_dim_order = get_dim_order(torch.contiguous_format, x.dim())
+        y = torch.ops.dim_order_ops._clone_dim_order.default(
+            x, dim_order=contiguous_dim_order
         )
 
-    def test_op_clone_replacement_channels_last(self) -> None:
-        model = SimpleCloneChannelsLastModule()
-        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
-            self,
-            MemoryFormatTestSet(
-                module=model.eval(),
-                op=torch.ops.aten.clone.default,
-                sample_input=(
-                    torch.randn((3, 4, 5, 6)).to(memory_format=torch.contiguous_format),
-                ),
-                target_memory_format=torch.channels_last,
-                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
-            ),
+        assert (
+            y.is_contiguous()
+        ), "_clone_dim_order output is not in contiguous memory format"
+        assert torch.allclose(x, y)
+
+    def test_op_clone_dim_order_out_to_channels_last(self) -> None:
+        x = torch.randn(2, 3, 4, 5).contiguous()
+        y = torch.empty_like(x, memory_format=torch.channels_last)
+        channels_last_dim_order = get_dim_order(torch.channels_last, y.dim())
+        torch.ops.dim_order_ops._clone_dim_order.out(
+            x, dim_order=channels_last_dim_order, out=y
         )
 
+        assert y.is_contiguous(
+            memory_format=torch.channels_last
+        ), "_clone_dim_order output is not in channels_last memory format"
+        assert torch.allclose(x, y)
+
     def test_op_dim_order_update(self) -> None:
         MemoryFormatOpsPassTestUtils.memory_format_test_runner(
             self,
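
For reviewers unfamiliar with dim orders, below is a minimal standalone sketch of what the new tests exercise. It assumes an ExecuTorch build where the dim_order_ops library is registered and get_dim_order is importable from executorch.exir.dim_order_utils (the same helper this test file uses); the commented values follow the convention that the contiguous order of a 4D tensor is [0, 1, 2, 3] and channels_last is [0, 2, 3, 1].

import torch
from executorch.exir.dim_order_utils import get_dim_order

# A dim order lists tensor dimensions from outermost to innermost stride:
# contiguous NCHW is [0, 1, 2, 3]; channels_last NHWC is [0, 2, 3, 1].
print(get_dim_order(torch.contiguous_format, 4))  # [0, 1, 2, 3]
print(get_dim_order(torch.channels_last, 4))      # [0, 2, 3, 1]

# _clone_dim_order mirrors aten.clone but takes an explicit dim_order
# argument instead of a memory_format, which is what the new tests cover.
x = torch.randn(2, 3, 4, 5)
y = torch.ops.dim_order_ops._clone_dim_order.default(
    x, dim_order=get_dim_order(torch.channels_last, x.dim())
)
assert y.is_contiguous(memory_format=torch.channels_last)
assert torch.allclose(x, y)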