|
91 | 91 | ReplaceScalarWithTensorArgPassTOSABI, |
92 | 92 | ReplaceScalarWithTensorArgPassTOSAMI, |
93 | 93 | RetraceFoldedDtypesPass, |
| 94 | + RewriteUpsamplePass, |
94 | 95 | ScalarsToAttributePass, |
95 | 96 | SizeAdjustInputPass, |
96 | 97 | ToTosaMemoryFormatPass, |
|
112 | 113 | from executorch.exir.pass_manager import PassManager |
113 | 114 | from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass |
114 | 115 | from torch.fx import GraphModule |
| 116 | +from torch.fx.passes.infra.pass_base import PassResult |
| 117 | +from torch.nn.modules import Module |
115 | 118 |
|
116 | 119 |
|
117 | 120 | class ArmPassManager(PassManager): |
@@ -204,6 +207,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: |
204 | 207 | # needs to happen before AddBiasPass, but after the table ops are inserted |
205 | 208 | # to be able to validate that conv2d has right dtype arguments. |
206 | 209 | self.add_pass(DecomposeConv2dWithInt16ActivationPass()) |
| 210 | + self.add_pass(RewriteUpsamplePass(exported_program)) |
207 | 211 | self.add_pass(AddBiasPass(exported_program)) |
208 | 212 |
|
209 | 213 | self.add_pass(FuseEqualPlaceholdersPass(exported_program)) |
@@ -288,6 +292,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: |
288 | 292 | self.add_pass(FuseViewCopyTransform()) |
289 | 293 | self.add_pass(FuseConstantArgsPass(exported_program)) |
290 | 294 | self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) |
| 295 | + self.add_pass(RewriteUpsamplePass(exported_program)) |
291 | 296 | self.add_pass(AddBiasPass(exported_program)) |
292 | 297 | self.add_pass(InsertTableOpsPass(exported_program)) |
293 | 298 | self.add_pass(FuseEqualPlaceholdersPass(exported_program)) |
@@ -355,3 +360,20 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): |
355 | 360 | self.add_pass(DecomposeMaskedFill()) |
356 | 361 |
|
357 | 362 | return self._transform(graph_module) |
| 363 | + |
| 364 | + def __call__(self, module: Module) -> PassResult: |
| 365 | + try: |
| 366 | + return super().__call__(module) |
| 367 | + except Exception as e: |
| 368 | + first_exception = e.__cause__ or e.__context__ or e |
| 369 | + import re |
| 370 | + |
| 371 | + message = e.args[0] |
| 372 | + m = re.search(r"An error occurred when running the '([^']+)' pass", message) |
| 373 | + if m: |
| 374 | + pass_name = m.group(1) |
| 375 | + first_exception.args = ( |
| 376 | + f"{pass_name}: {first_exception.args[0]}", |
| 377 | + *first_exception.args[1:], |
| 378 | + ) |
| 379 | + raise first_exception |
0 commit comments