|
26 | 26 | ForceChannelLastForConvPass,
|
27 | 27 | MakeSliceAndCatDimOutermostPass,
|
28 | 28 | ReplaceAddMMWithLinearPass,
|
| 29 | + ReplaceAtenApproxGeluWithApproxGeluPass, |
29 | 30 | ReplaceAtenConvolutionWithJarvisConvolutionPass,
|
30 | 31 | ReplaceConstantPadNdWithSlicePass,
|
31 | 32 | ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
|
32 | 33 | ReplaceConvWithIm2RowAndLinear,
|
33 | 34 | ReplaceEmptyTensorsWithFullPass,
|
34 | 35 | ReplaceFunctionallyEquivalentOpTargets,
|
35 |
| - ReplaceGeluWithApproximateGeluPass, |
36 | 36 | ReplaceIm2RowWithViewPass,
|
37 | 37 | ReplaceLinearWithFullyConnectedOpPass,
|
38 | 38 | ReplaceMatmulWithTransposedMatmulPass,
|
@@ -1287,17 +1287,41 @@ def forward(self, cond: torch.Tensor):
|
1287 | 1287 | 1,
|
1288 | 1288 | )
|
1289 | 1289 |
|
1290 |
| - def test_replace_aten_gelu_with_approximate_gelu(self): |
1291 |
| - class Gelu(torch.nn.Module): |
1292 |
| - def forward(self, input): |
1293 |
| - return torch.nn.functional.gelu(input) |
| 1290 | + def test_no_replace_aten_gelu_with_approximate_gelu(self): |
| 1291 | + inputs = torch.randn(2, 1, 64) |
| 1292 | + |
| 1293 | + gm = single_op_builder( |
| 1294 | + placeholders=(inputs,), |
| 1295 | + op=exir_ops.edge.aten.gelu.default, |
| 1296 | + args=(inputs,), |
| 1297 | + ) |
| 1298 | + gm = ExportPass().call(gm).graph_module |
| 1299 | + |
| 1300 | + p = ReplaceAtenApproxGeluWithApproxGeluPass() |
| 1301 | + graph_after_passes = p.call(gm).graph_module |
1294 | 1302 |
|
| 1303 | + # Assert that aten.gelu op was not decomposed, since it didn't have an approximate argument |
| 1304 | + self.assertEqual( |
| 1305 | + count_node( |
| 1306 | + graph_after_passes, |
| 1307 | + exir_ops.edge.aten.gelu.default, |
| 1308 | + ), |
| 1309 | + 1, |
| 1310 | + ) |
| 1311 | + |
| 1312 | + def test_replace_aten_approximate_gelu_with_approximate_gelu(self): |
1295 | 1313 | inputs = torch.randn(2, 1, 64)
|
1296 | 1314 |
|
1297 |
| - graph_module = export_to_edge(Gelu(), (inputs,)).exported_program().graph_module |
| 1315 | + gm = single_op_builder( |
| 1316 | + placeholders=(inputs,), |
| 1317 | + op=exir_ops.edge.aten.gelu.default, |
| 1318 | + args=(inputs,), |
| 1319 | + kwargs={"approximate": "tanh"}, |
| 1320 | + ) |
| 1321 | + gm = ExportPass().call(gm).graph_module |
1298 | 1322 |
|
1299 |
| - p = ReplaceGeluWithApproximateGeluPass() |
1300 |
| - graph_after_passes = cast(PassResult, p(graph_module)).graph_module |
| 1323 | + p = ReplaceAtenApproxGeluWithApproxGeluPass() |
| 1324 | + graph_after_passes = p.call(gm).graph_module |
1301 | 1325 |
|
1302 | 1326 | # Assert that aten.gelu op was decomposed
|
1303 | 1327 | self.assertEqual(
|
|
0 commit comments