|
7 | 7 | import unittest |
8 | 8 |
|
import torch

# Fixed: the pass must be imported from the top-level `executorch` package,
# matching the sibling import below — the bare `backends...` path is not
# importable from a normal install.
from executorch.backends.xnnpack._passes.convert_to_linear import (
    ConvertToLinearPass,
)
from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
    ChannelsLastTaggedReshapePass,
)
@@ -58,6 +59,88 @@ def test_fp32_channels_last_tagged_reshape_pass(self): |
58 | 59 | .run_method_and_compare_outputs() |
59 | 60 | ) |
60 | 61 |
|
class LinearConv(torch.nn.Module):
    """Linear (NCHW-friendly) feeding a Conv2d (NHWC-lowered op).

    Used to exercise dim-order conversions inserted between a linear op
    and a convolution when lowering to XNNPACK.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3)
        self.linear1 = torch.nn.Linear(4, 3)

    def forward(self, x):
        # Same dataflow as linear-then-conv, written as a single expression.
        return self.conv1(self.linear1(x))
def test_conv_linear_dim_order_swaps_on_nhwc_input(self):
    """Lowered linear->conv model must accept a channels-last input."""
    sample = (torch.randn(1, 3, 6, 4).to(memory_format=torch.channels_last),)
    (
        Tester(self.LinearConv().eval(), sample)
        .export()
        .to_edge_transform_and_lower()
        .to_executorch()
        .serialize()
        .run_method_and_compare_outputs()
    )
def test_conv_linear_dim_order_swaps_on_nchw_input(self):
    """Lowered linear->conv model must accept a contiguous (NCHW) input."""
    sample = (torch.randn(1, 3, 6, 4),)
    (
        Tester(self.LinearConv().eval(), sample)
        .export()
        .to_edge_transform_and_lower()
        .to_executorch()
        .serialize()
        .run_method_and_compare_outputs()
    )
| 87 | + |
class ConvLinearConv(torch.nn.Module):
    """Conv2d (NHWC-lowered op) feeding a Linear.

    Mirror image of LinearConv: the dim-order conversion must be inserted
    after the convolution instead of before it.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3)
        self.linear1 = torch.nn.Linear(4, 4)

    def forward(self, x):
        conv_out = self.conv1(x)
        return self.linear1(conv_out)
| 97 | + |
def test_linear_conv_dim_order_swaps_on_nhwc_input(self):
    """Lowered conv->linear model must accept a channels-last input."""
    sample = (torch.randn(1, 3, 6, 6).to(memory_format=torch.channels_last),)
    (
        Tester(self.ConvLinearConv().eval(), sample)
        .export()
        .to_edge_transform_and_lower()
        .to_executorch()
        .serialize()
        .run_method_and_compare_outputs()
    )
| 105 | + |
def test_linear_conv_dim_order_swaps_on_nchw_input(self):
    """Lowered conv->linear model must accept a contiguous (NCHW) input."""
    sample = (torch.randn(1, 3, 6, 6),)
    (
        Tester(self.ConvLinearConv().eval(), sample)
        .export()
        .to_edge_transform_and_lower()
        .to_executorch()
        .serialize()
        .run_method_and_compare_outputs()
    )
| 113 | + |
class Bilinear(torch.nn.Module):
    """Upsamples its input 2x via bilinear interpolation (align_corners=True).

    Stateless wrapper around ``torch.nn.functional.interpolate``; used by the
    ``*_on_nhwc_op`` tests, which feed it channels-last and contiguous inputs.
    """

    # No __init__ override: the module holds no parameters or buffers, so the
    # empty super()-only constructor was pure boilerplate and is removed.

    def forward(self, x):
        return torch.nn.functional.interpolate(
            x, scale_factor=2, mode="bilinear", align_corners=True
        )
| 122 | + |
def test_nhwc_input_on_nhwc_op(self):
    """Channels-last input fed straight into an NHWC-lowered op."""
    data = torch.arange(8).reshape(1, 2, 2, 2).to(torch.float32)
    sample = (data.to(memory_format=torch.channels_last),)
    (
        Tester(self.Bilinear().eval(), sample)
        .export()
        .to_edge_transform_and_lower()
        .to_executorch()
        .serialize()
        .run_method_and_compare_outputs()
    )
| 135 | + |
def test_nchw_input_on_nhwc_op(self):
    """Contiguous (NCHW) input fed into an NHWC-lowered op."""
    sample = (torch.arange(8).reshape(1, 2, 2, 2).to(torch.float32),)
    (
        Tester(self.Bilinear().eval(), sample)
        .export()
        .to_edge_transform_and_lower()
        .to_executorch()
        .serialize()
        .run_method_and_compare_outputs()
    )
| 143 | + |
61 | 144 | def test_qs8_channels_last_tagged_reshape_pass(self): |
62 | 145 | for module, num_reshape in self.modules.items(): |
63 | 146 | ( |
|
0 commit comments