|
18 | 18 | ComputeConstantOpsAOT, |
19 | 19 | Conv1dUnsqueezePass, |
20 | 20 | ConvertAnyDefaultDimDimsPass, |
| 21 | + ConvertELUParamsPass, |
21 | 22 | ConvertExpandCopyToRepeatPass, |
22 | 23 | ConvertFullLikeToFullPass, |
| 24 | + ConvertInt64ConstOpsToInt32Pass, |
| 25 | + ConvertInt64OutputOpsToInt32Pass, |
23 | 26 | ConvertIntPowToMuls, |
24 | 27 | ConvertMinMaxPass, |
25 | 28 | ConvertMmToBmmPass, |
|
39 | 42 | DecomposeCosineSimilarityPass, |
40 | 43 | DecomposeCumsumPass, |
41 | 44 | DecomposeDivPass, |
| 45 | + DecomposeEluPass, |
42 | 46 | DecomposeEmbeddingPass, |
43 | 47 | DecomposeExpm1Pass, |
44 | 48 | DecomposeGeluPass, |
|
98 | 102 | from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass |
99 | 103 | from executorch.exir import ExportedProgram |
100 | 104 | from executorch.exir.pass_manager import PassManager |
| 105 | +from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass |
101 | 106 | from torch.fx import GraphModule |
102 | 107 |
|
103 | 108 |
|
@@ -132,6 +137,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: |
132 | 137 | self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) |
133 | 138 | self.add_pass(AnnotateDecomposedMatmulPass()) |
134 | 139 | self.add_pass(QuantizeOperatorArguments()) |
| 140 | + self.add_pass(ConvertELUParamsPass()) |
135 | 141 | self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg] |
136 | 142 | self.add_pass(RetraceFoldedDtypesPass()) |
137 | 143 | self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) |
@@ -180,6 +186,8 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: |
180 | 186 | self.add_pass(DecomposeAtanPass()) |
181 | 187 | self.add_pass(DecomposeAtanhPass()) |
182 | 188 | self.add_pass(DecomposeAddmmPass()) |
| 189 | + self.add_pass(DecomposeEluPass()) |
| 190 | + self.add_pass(DecomposeExpm1Pass()) |
183 | 191 | self.add_pass(ConvertIntPowToMuls()) |
184 | 192 | self.add_pass(CastBoolToInt8Pass()) |
185 | 193 | self.add_pass(DecomposeSinhPass()) |
@@ -258,6 +266,11 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram): |
258 | 266 | ) |
259 | 267 |
|
260 | 268 | def transform_for_annotation_pipeline(self, graph_module: GraphModule): |
| 269 | + self.add_pass( |
| 270 | + RemoveGraphAssertsPass() |
| 271 | + ) # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assertation in Graph |
| 272 | + self.add_pass(ConvertInt64ConstOpsToInt32Pass()) |
| 273 | + self.add_pass(ConvertInt64OutputOpsToInt32Pass()) |
261 | 274 | self.add_pass(InsertCastForOpsWithInt64InputPass()) |
262 | 275 | self.add_pass(DecomposeEmbeddingPass()) |
263 | 276 | self.add_pass(DecomposeScaledDotProductAttention()) |
|
0 commit comments