|
14 | 14 | # ============================================================================== |
15 | 15 |
|
16 | 16 | import logging |
17 | | -import os |
18 | 17 | from typing import Any, Literal, Optional, Union |
19 | 18 |
|
| 19 | +import ai_edge_torch |
20 | 20 | from ai_edge_torch import fx_pass_base |
21 | 21 | from ai_edge_torch import lowertools |
22 | 22 | from ai_edge_torch import model |
|
26 | 26 | from ai_edge_torch.quantize import quant_config as qcfg |
27 | 27 | import torch |
28 | 28 |
|
29 | | -os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1" |
30 | | - |
31 | 29 |
|
32 | 30 | def _run_convert_passes( |
33 | 31 | exported_program: torch.export.ExportedProgram, |
34 | 32 | ) -> torch.export.ExportedProgram: |
35 | 33 | exported_program = generative_fx_passes.run_generative_passes( |
36 | 34 | exported_program |
37 | 35 | ) |
38 | | - exported_program = fx_pass_base.run_passes( |
39 | | - exported_program, |
40 | | - [ |
41 | | - fx_passes.BuildInterpolateCompositePass(), |
42 | | - fx_passes.CanonicalizePass(), |
43 | | - fx_passes.OptimizeLayoutTransposesPass(), |
44 | | - fx_passes.CanonicalizePass(), |
45 | | - fx_passes.BuildAtenCompositePass(), |
46 | | - fx_passes.CanonicalizePass(), |
47 | | - fx_passes.RemoveNonUserOutputsPass(), |
48 | | - fx_passes.CanonicalizePass(), |
49 | | - fx_passes.InjectMlirDebuginfoPass(), |
50 | | - fx_passes.CanonicalizePass(), |
51 | | - ], |
52 | | - ) |
| 36 | + |
| 37 | + passes = [ |
| 38 | + fx_passes.BuildInterpolateCompositePass(), |
| 39 | + fx_passes.CanonicalizePass(), |
| 40 | + fx_passes.OptimizeLayoutTransposesPass(), |
| 41 | + fx_passes.CanonicalizePass(), |
| 42 | + fx_passes.BuildAtenCompositePass(), |
| 43 | + fx_passes.CanonicalizePass(), |
| 44 | + fx_passes.RemoveNonUserOutputsPass(), |
| 45 | + fx_passes.CanonicalizePass(), |
| 46 | + ] |
| 47 | + |
| 48 | + # Debuginfo is not injected automatically by odml_torch. Only inject |
| 49 | + # debuginfo via fx pass when using torch_xla. |
| 50 | + if ai_edge_torch.config.use_torch_xla: |
| 51 | + passes += [ |
| 52 | + fx_passes.InjectMlirDebuginfoPass(), |
| 53 | + fx_passes.CanonicalizePass(), |
| 54 | + ] |
| 55 | + |
| 56 | + exported_program = fx_pass_base.run_passes(exported_program, passes) |
53 | 57 | return exported_program |
54 | 58 |
|
55 | 59 |
|
|
0 commit comments