Skip to content

Commit 7e68dce

Browse files
chunnienccopybara-github
authored andcommitted
Deprecate InjectMlirDebuginfoPass for odml_torch default migration
PiperOrigin-RevId: 711867888
1 parent 7407150 commit 7e68dce

File tree

4 files changed

+26
-99
lines changed

4 files changed

+26
-99
lines changed

ai_edge_torch/_convert/conversion.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
# ==============================================================================
1515

1616
import logging
17-
import os
1817
from typing import Any, Literal, Optional, Union
1918

19+
import ai_edge_torch
2020
from ai_edge_torch import fx_pass_base
2121
from ai_edge_torch import lowertools
2222
from ai_edge_torch import model
@@ -26,30 +26,34 @@
2626
from ai_edge_torch.quantize import quant_config as qcfg
2727
import torch
2828

29-
os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
30-
3129

3230
def _run_convert_passes(
3331
exported_program: torch.export.ExportedProgram,
3432
) -> torch.export.ExportedProgram:
3533
exported_program = generative_fx_passes.run_generative_passes(
3634
exported_program
3735
)
38-
exported_program = fx_pass_base.run_passes(
39-
exported_program,
40-
[
41-
fx_passes.BuildInterpolateCompositePass(),
42-
fx_passes.CanonicalizePass(),
43-
fx_passes.OptimizeLayoutTransposesPass(),
44-
fx_passes.CanonicalizePass(),
45-
fx_passes.BuildAtenCompositePass(),
46-
fx_passes.CanonicalizePass(),
47-
fx_passes.RemoveNonUserOutputsPass(),
48-
fx_passes.CanonicalizePass(),
49-
fx_passes.InjectMlirDebuginfoPass(),
50-
fx_passes.CanonicalizePass(),
51-
],
52-
)
36+
37+
passes = [
38+
fx_passes.BuildInterpolateCompositePass(),
39+
fx_passes.CanonicalizePass(),
40+
fx_passes.OptimizeLayoutTransposesPass(),
41+
fx_passes.CanonicalizePass(),
42+
fx_passes.BuildAtenCompositePass(),
43+
fx_passes.CanonicalizePass(),
44+
fx_passes.RemoveNonUserOutputsPass(),
45+
fx_passes.CanonicalizePass(),
46+
]
47+
48+
# Debuginfo is not injected automatically by odml_torch. Only inject
49+
# debuginfo via fx pass when using torch_xla.
50+
if ai_edge_torch.config.use_torch_xla:
51+
passes += [
52+
fx_passes.InjectMlirDebuginfoPass(),
53+
fx_passes.CanonicalizePass(),
54+
]
55+
56+
exported_program = fx_pass_base.run_passes(exported_program, passes)
5357
return exported_program
5458

5559

ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def debuginfo_writer(*args, **kwargs):
6262

6363

6464
class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
65+
"""DEPRECATED: Debuginfo is injected automatically by odml_torch."""
6566

6667
def call(self, graph_module: torch.fx.GraphModule):
6768
for node in graph_module.graph.nodes:

ai_edge_torch/_convert/fx_passes/test/test_inject_mlir_debuginfo_pass.py

Lines changed: 0 additions & 81 deletions
This file was deleted.

ai_edge_torch/lowertools/torch_xla_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
# https://github.com/google-ai-edge/ai-edge-torch/issues/326
2828
os.environ["PJRT_DEVICE"] = "CPU"
2929

30+
os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
31+
32+
3033
from ai_edge_torch import model
3134
from ai_edge_torch._convert import conversion_utils
3235
from ai_edge_torch._convert import signature as signature_module

0 commit comments

Comments
 (0)