|
12 | 12 | FuseConstantArgsPass,
|
13 | 13 | )
|
14 | 14 | from executorch.backends.arm.test import common
|
15 |
| -from executorch.backends.arm.test.tester.test_pipeline import PassPipeline |
| 15 | +from executorch.backends.arm.test.tester.test_pipeline import ( |
| 16 | + PassPipeline, |
| 17 | + TosaPipelineFP, |
| 18 | + TosaPipelineINT, |
| 19 | +) |
16 | 20 |
|
17 | 21 | input_t = Tuple[torch.Tensor] # Input x
|
18 | 22 | input_t2 = Tuple[torch.Tensor, torch.Tensor]
|
@@ -103,6 +107,22 @@ def forward(self, a, b):
|
103 | 107 | return torch.cat((a, b), dim=0)
|
104 | 108 |
|
105 | 109 |
|
| 110 | +class LinearConst(torch.nn.Module): |
| 111 | + """A linear layer that can be computed AOT""" |
| 112 | + |
| 113 | + def __init__(self, in_out_features: int = 3, bias: bool = True): |
| 114 | + super().__init__() |
| 115 | + self.linear = torch.nn.Linear(in_out_features, in_out_features, bias=bias) |
| 116 | + self.example_input = torch.rand(in_out_features, in_out_features) |
| 117 | + |
| 118 | + def forward(self, x: torch.Tensor): |
| 119 | + y = torch.full_like(x, 1.0) |
| 120 | + return self.linear(y) + x |
| 121 | + |
| 122 | + def get_example_input(self): |
| 123 | + return self.example_input |
| 124 | + |
| 125 | + |
106 | 126 | modules = {
|
107 | 127 | "fuse_parameter": FuseParameter(),
|
108 | 128 | "fuse_buffer": FuseBuffer(),
|
@@ -152,3 +172,30 @@ def test_fuse_const_ops_tosa_BI_cat(module: torch.nn.Module):
|
152 | 172 | passes_with_exported_program=[ComputeConstantOpsAOT, FuseConstantArgsPass],
|
153 | 173 | )
|
154 | 174 | pipeline.run()
|
| 175 | + |
| 176 | + |
| 177 | +def test_linear_const_tosa_FP(): |
| 178 | + model = LinearConst() |
| 179 | + example_input = model.get_example_input() |
| 180 | + pipeline = TosaPipelineFP[input_t]( |
| 181 | + model, |
| 182 | + (example_input,), |
| 183 | + aten_op=[], |
| 184 | + exir_op=[], |
| 185 | + use_to_edge_transform_and_lower=True, |
| 186 | + ) |
| 187 | + pipeline.run() |
| 188 | + |
| 189 | + |
| 190 | +def test_linear_const_tosa_INT(): |
| 191 | + model = LinearConst() |
| 192 | + example_input = model.get_example_input() |
| 193 | + pipeline = TosaPipelineINT[input_t]( |
| 194 | + model, |
| 195 | + (example_input,), |
| 196 | + aten_op=[], |
| 197 | + exir_op=[], |
| 198 | + per_channel_quantization=False, |
| 199 | + use_to_edge_transform_and_lower=True, |
| 200 | + ) |
| 201 | + pipeline.run() |
0 commit comments