|
27 | 27 | AmbiguousDimOrderError, |
28 | 28 | MemoryFormatOpsPassTestUtils, |
29 | 29 | MemoryFormatTestSet, |
| 30 | + PropagateToCloneChannelsLastModule, |
30 | 31 | PropagateToCopyChannalsLastModule, |
| 32 | + SimpleCloneChannelsLastModule, |
| 33 | + SimpleCloneContiguousModule, |
31 | 34 | SimpleEmptyChannelLastModule, |
32 | 35 | SimpleEmptyContiguoustModule, |
33 | 36 | SimpleToCopyChannelsLastModule, |
@@ -91,6 +94,36 @@ def test_op_empty_replacement_contiguous(self) -> None: |
91 | 94 | ), |
92 | 95 | ) |
93 | 96 |
|
| 97 | + def test_op_clone_replacement_contiguous(self) -> None: |
| 98 | + model = SimpleCloneContiguousModule() |
| 99 | + MemoryFormatOpsPassTestUtils.memory_format_test_runner( |
| 100 | + self, |
| 101 | + MemoryFormatTestSet( |
| 102 | + module=model.eval(), |
| 103 | + op=torch.ops.aten.clone.default, |
| 104 | + sample_input=( |
| 105 | + torch.randn((3, 4, 5, 6)).to(memory_format=torch.channels_last), |
| 106 | + ), |
| 107 | + target_memory_format=torch.contiguous_format, |
| 108 | + _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, |
| 109 | + ), |
| 110 | + ) |
| 111 | + |
| 112 | + def test_op_clone_replacement_channels_last(self) -> None: |
| 113 | + model = SimpleCloneChannelsLastModule() |
| 114 | + MemoryFormatOpsPassTestUtils.memory_format_test_runner( |
| 115 | + self, |
| 116 | + MemoryFormatTestSet( |
| 117 | + module=model.eval(), |
| 118 | + op=torch.ops.aten.clone.default, |
| 119 | + sample_input=( |
| 120 | + torch.randn((3, 4, 5, 6)).to(memory_format=torch.contiguous_format), |
| 121 | + ), |
| 122 | + target_memory_format=torch.channels_last, |
| 123 | + _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, |
| 124 | + ), |
| 125 | + ) |
| 126 | + |
94 | 127 | def test_op_dim_order_update(self) -> None: |
95 | 128 | MemoryFormatOpsPassTestUtils.memory_format_test_runner( |
96 | 129 | self, |
@@ -128,6 +161,25 @@ def test_op_dim_order_propagation(self) -> None: |
128 | 161 | check_unambiguous_dim_order=True, |
129 | 162 | ) |
130 | 163 |
|
| 164 | + def test_op_clone_dim_order_propagation(self) -> None: |
| 165 | + MemoryFormatOpsPassTestUtils.memory_format_test_runner( |
| 166 | + self, |
| 167 | + MemoryFormatTestSet( |
| 168 | + module=PropagateToCloneChannelsLastModule().eval(), |
| 169 | + op=torch.ops.aten.clone.default, |
| 170 | + sample_input=( |
| 171 | + torch.rand_like( |
| 172 | + torch.zeros([2, 2, 2, 2]), |
| 173 | + dtype=torch.float32, |
| 174 | + memory_format=torch.contiguous_format, |
| 175 | + ), |
| 176 | + ), |
| 177 | + target_memory_format=torch.channels_last, |
| 178 | + _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, |
| 179 | + ), |
| 180 | + check_unambiguous_dim_order=True, |
| 181 | + ) |
| 182 | + |
131 | 183 | def test_op_dim_order_propagation_ambiguous(self) -> None: |
132 | 184 | try: |
133 | 185 | MemoryFormatOpsPassTestUtils.memory_format_test_runner( |
|
0 commit comments