|
7 | 7 | import unittest |
8 | 8 |
|
9 | 9 | import torch |
| 10 | +from backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass |
10 | 11 | from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import ( |
11 | 12 | ChannelsLastTaggedReshapePass, |
12 | 13 | ) |
@@ -58,48 +59,87 @@ def setUp(self): |
58 | 59 | # .run_method_and_compare_outputs() |
59 | 60 | # ) |
60 | 61 |
|
61 | | - # def test_channels_last_input_graph_transformation(self): |
62 | | - # # Define a simple module for testing |
63 | | - # class SimpleModule(torch.nn.Module): |
64 | | - # def __init__(self): |
65 | | - # super().__init__() |
66 | | - # self.conv = torch.nn.Conv2d(3, 3, 3) |
67 | | - # def forward(self, x): |
68 | | - # return self.conv(x) |
69 | | - # # Create a tester instance with NHWC input |
70 | | - # tester = Tester(SimpleModule().eval(), (torch.randn(1, 3, 3, 3).to(memory_format=torch.channels_last),)) |
71 | | - # # Run the export and pass stages |
72 | | - # tester.export().to_edge().run_passes(self.PassStage) |
73 | | - # # Check the graph for expected nodes |
74 | | - # tester.check_count({ |
75 | | - # "executorch_exir_dialects_edge__ops_aten__to_copy_default": 2, # should be 1 but its 2 |
76 | | - # "executorch_exir_dialects_edge__ops_aten_convolution_default": 1 |
77 | | - # }) |
78 | | - # tester.dump_artifact() |
79 | | - |
80 | | - def test_nhwc_input(self): |
81 | | - class SimpleModule(torch.nn.Module): |
82 | | - def __init__(self): |
83 | | - super().__init__() |
84 | | - self.conv = torch.nn.Conv2d(3, 3, 3) |
85 | | - def forward(self, x): |
86 | | - return self.conv(x) |
87 | | - |
88 | | - tester = Tester(SimpleModule().eval(), (torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last),)) |
89 | | - |
90 | | - tester2 = Tester(SimpleModule().eval(), (torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last),)) |
91 | | - tester2.export().to_edge().run_passes(self.PassStage).dump_artifact() |
92 | | - |
93 | | - |
94 | | - tester.export() \ |
95 | | - .to_edge_transform_and_lower() \ |
96 | | - .dump_artifact()\ |
97 | | - .to_executorch() \ |
98 | | - .dump_artifact()\ |
99 | | - .serialize() \ |
100 | | - .run_method_and_compare_outputs() |
| 62 | + class LinearConv(torch.nn.Module): |
| 63 | + def __init__(self): |
| 64 | + super().__init__() |
| 65 | + self.conv1 = torch.nn.Conv2d(3, 3, 3) |
| 66 | + self.linear1 = torch.nn.Linear(4, 3) |
| 67 | + |
| 68 | + def forward(self, x): |
| 69 | + y = self.linear1(x) |
| 70 | + return self.conv1(y) |
| 71 | + |
| 72 | + def test_conv_linear_dim_order_swaps_on_nhwc_input(self): |
| 73 | + tester = Tester( |
| 74 | + self.LinearConv().eval(), |
| 75 | + (torch.randn(1, 3, 6, 4).to(memory_format=torch.channels_last),), |
| 76 | + ) |
| 77 | + |
| 78 | + tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs() |
| 79 | + |
| 80 | + def test_conv_linear_dim_order_swaps_on_nchw_input(self): |
| 81 | + tester = Tester( |
| 82 | + self.LinearConv().eval(), |
| 83 | + (torch.randn(1, 3, 6, 4),), |
| 84 | + ) |
| 85 | + |
| 86 | + tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs() |
| 87 | + |
| 88 | + class ConvLinearConv(torch.nn.Module): |
| 89 | + def __init__(self): |
| 90 | + super().__init__() |
| 91 | + self.conv1 = torch.nn.Conv2d(3, 3, 3) |
| 92 | + self.linear1 = torch.nn.Linear(4, 4) |
| 93 | + |
| 94 | + def forward(self, x): |
| 95 | + y = self.conv1(x) |
| 96 | + return self.linear1(y) |
| 97 | + |
| 98 | + def test_linear_conv_dim_order_swaps_on_nhwc_input(self): |
| 99 | + tester = Tester( |
| 100 | + self.ConvLinearConv().eval(), |
| 101 | + (torch.randn(1, 3, 6, 6).to(memory_format=torch.channels_last),), |
| 102 | + ) |
101 | 103 |
|
| 104 | + tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs() |
102 | 105 |
|
| 106 | + def test_linear_conv_dim_order_swaps_on_nchw_input(self): |
| 107 | + tester = Tester( |
| 108 | + self.ConvLinearConv().eval(), |
| 109 | + (torch.randn(1, 3, 6, 6),), |
| 110 | + ) |
| 111 | + |
| 112 | + tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs() |
| 113 | + |
| 114 | + class Bilinear(torch.nn.Module): |
| 115 | + def __init__(self): |
| 116 | + super().__init__() |
| 117 | + |
| 118 | + def forward(self, x): |
| 119 | + return torch.nn.functional.interpolate( |
| 120 | + x, scale_factor=2, mode="bilinear", align_corners=True |
| 121 | + ) |
| 122 | + |
| 123 | + def test_nhwc_input_on_nhwc_op(self): |
| 124 | + tester = Tester( |
| 125 | + self.Bilinear().eval(), |
| 126 | + ( |
| 127 | + torch.arange(8) |
| 128 | + .reshape(1, 2, 2, 2) |
| 129 | + .to(torch.float32) |
| 130 | + .to(memory_format=torch.channels_last), |
| 131 | + ), |
| 132 | + ) |
| 133 | + |
| 134 | + tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs() |
| 135 | + |
| 136 | + def test_nchw_input_on_nhwc_op(self): |
| 137 | + tester = Tester( |
| 138 | + self.Bilinear().eval(), |
| 139 | + (torch.arange(8).reshape(1, 2, 2, 2).to(torch.float32),), |
| 140 | + ) |
| 141 | + |
| 142 | + tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs() |
103 | 143 |
|
104 | 144 | # def test_qs8_channels_last_tagged_reshape_pass(self): |
105 | 145 | # for module, num_reshape in self.modules.items(): |
@@ -190,45 +230,45 @@ def forward(self, x): |
190 | 230 | return x |
191 | 231 |
|
192 | 232 | # def test_fp32_channels_last_tagged_reshape_pass_conv_bn_hardtanh_mean_seq(self): |
193 | | - # Copy #1 is for input to conv, nchw -> nhwc |
194 | | - # Copy #2 is for conv to _native_batch_norm_legit_no_training, nhwc -> nchw |
195 | | - # Copy #3 is for input to mean, nchw -> nhwc |
196 | | - # Copy #4 is for output, nhwc -> nchw |
197 | | - |
198 | | - # The graph looks like: |
199 | | - # graph(): |
200 | | - # %arg0_1 : [#users=1] = placeholder[target=arg0_1] |
201 | | - # %aten__to_copy_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%arg0_1,), kwargs = {memory_format: torch.channels_last}) |
202 | | - # %_param_constant0 : [#users=1] = get_attr[target=_param_constant0] |
203 | | - # %_param_constant1 : [#users=1] = get_attr[target=_param_constant1] |
204 | | - # %aten_convolution_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten__to_copy_default, %_param_constant0, %_param_constant1, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), kwargs = {}) |
205 | | - # %aten__to_copy_default_1 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_convolution_default,), kwargs = {memory_format: torch.contiguous_format}) |
206 | | - # %_param_constant2 : [#users=1] = get_attr[target=_param_constant2] |
207 | | - # %_param_constant3 : [#users=1] = get_attr[target=_param_constant3] |
208 | | - # %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0] |
209 | | - # %_tensor_constant1 : [#users=1] = get_attr[target=_tensor_constant1] |
210 | | - # %aten__native_batch_norm_legit_no_training_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten__to_copy_default_1, %_param_constant2, %_param_constant3, %_tensor_constant0, %_tensor_constant1, 0.1, 1e-05), kwargs = {}) |
211 | | - # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default, 0), kwargs = {}) |
212 | | - # %aten_hardtanh_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem, 0, 6), kwargs = {}) |
213 | | - # %aten__to_copy_default_2 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_hardtanh_default,), kwargs = {memory_format: torch.channels_last}) |
214 | | - # %aten_mean_dim : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mean.dim](args = (%aten__to_copy_default_2, [-1, -2], True), kwargs = {}) |
215 | | - # %aten__to_copy_default_3 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_mean_dim,), kwargs = {memory_format: torch.contiguous_format}) |
216 | | - # return [aten__to_copy_default_3] |
217 | | - # ( |
218 | | - # Tester( |
219 | | - # self.Conv2dBnHardtanhMeanSequenceModule().eval(), |
220 | | - # (torch.randn(1, 1, 6, 6),), |
221 | | - # ) |
222 | | - # .export() |
223 | | - # .to_edge() |
224 | | - # .run_passes(self.PassStage) |
225 | | - # .check_count( |
226 | | - # { |
227 | | - # self.to_copy_name: 4, |
228 | | - # } |
229 | | - # ) |
230 | | - # .run_method_and_compare_outputs() |
231 | | - # ) |
| 233 | + # Copy #1 is for input to conv, nchw -> nhwc |
| 234 | + # Copy #2 is for conv to _native_batch_norm_legit_no_training, nhwc -> nchw |
| 235 | + # Copy #3 is for input to mean, nchw -> nhwc |
| 236 | + # Copy #4 is for output, nhwc -> nchw |
| 237 | + |
| 238 | + # The graph looks like: |
| 239 | + # graph(): |
| 240 | + # %arg0_1 : [#users=1] = placeholder[target=arg0_1] |
| 241 | + # %aten__to_copy_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%arg0_1,), kwargs = {memory_format: torch.channels_last}) |
| 242 | + # %_param_constant0 : [#users=1] = get_attr[target=_param_constant0] |
| 243 | + # %_param_constant1 : [#users=1] = get_attr[target=_param_constant1] |
| 244 | + # %aten_convolution_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten__to_copy_default, %_param_constant0, %_param_constant1, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), kwargs = {}) |
| 245 | + # %aten__to_copy_default_1 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_convolution_default,), kwargs = {memory_format: torch.contiguous_format}) |
| 246 | + # %_param_constant2 : [#users=1] = get_attr[target=_param_constant2] |
| 247 | + # %_param_constant3 : [#users=1] = get_attr[target=_param_constant3] |
| 248 | + # %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0] |
| 249 | + # %_tensor_constant1 : [#users=1] = get_attr[target=_tensor_constant1] |
| 250 | + # %aten__native_batch_norm_legit_no_training_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten__to_copy_default_1, %_param_constant2, %_param_constant3, %_tensor_constant0, %_tensor_constant1, 0.1, 1e-05), kwargs = {}) |
| 251 | + # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default, 0), kwargs = {}) |
| 252 | + # %aten_hardtanh_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem, 0, 6), kwargs = {}) |
| 253 | + # %aten__to_copy_default_2 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_hardtanh_default,), kwargs = {memory_format: torch.channels_last}) |
| 254 | + # %aten_mean_dim : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mean.dim](args = (%aten__to_copy_default_2, [-1, -2], True), kwargs = {}) |
| 255 | + # %aten__to_copy_default_3 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_mean_dim,), kwargs = {memory_format: torch.contiguous_format}) |
| 256 | + # return [aten__to_copy_default_3] |
| 257 | + # ( |
| 258 | + # Tester( |
| 259 | + # self.Conv2dBnHardtanhMeanSequenceModule().eval(), |
| 260 | + # (torch.randn(1, 1, 6, 6),), |
| 261 | + # ) |
| 262 | + # .export() |
| 263 | + # .to_edge() |
| 264 | + # .run_passes(self.PassStage) |
| 265 | + # .check_count( |
| 266 | + # { |
| 267 | + # self.to_copy_name: 4, |
| 268 | + # } |
| 269 | + # ) |
| 270 | + # .run_method_and_compare_outputs() |
| 271 | + # ) |
232 | 272 |
|
233 | 273 | class Conv2dDynamicQuant(torch.nn.Module): |
234 | 274 | def __init__(self): |
|
0 commit comments