Commit 3f5944c

yijie-yang authored and copybara-github committed
Clean up history stack before fx passes in ODML torch lowering
PiperOrigin-RevId: 731019293
1 parent 2c23904 commit 3f5944c

4 files changed: +93 −18 lines

ai_edge_torch/_convert/conversion.py

Lines changed: 4 additions & 0 deletions
@@ -125,6 +125,10 @@ def export(**kwargs):
   else:
     exported_program = torch.export.export(**kwargs, strict=True)
 
+  exported_program = fx_infra.graph_utils.reset_from_node_meta(
+      exported_program
+  )
+
   exported_program = fx_infra.safe_run_decompositions(
       exported_program,
       fx_infra.decomp.pre_convert_decomp(),
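
For orientation, the resulting flow in conversion.py looks roughly like the sketch below: the "from_node" history is reset immediately after torch.export, before any decomposition or fx pass runs. This is a minimal sketch assuming only the fx_infra helpers named in this diff; the surrounding pipeline is simplified and the wrapper name is illustrative.

import torch

from ai_edge_torch import fx_infra


def _export_and_prepare(model: torch.nn.Module, args):
  # Illustrative wrapper, not part of the commit.
  exported_program = torch.export.export(model, args, strict=True)
  # Reset every node's "from_node" history to a single canonical entry
  # before fx passes start appending to it.
  exported_program = fx_infra.graph_utils.reset_from_node_meta(
      exported_program
  )
  return fx_infra.safe_run_decompositions(
      exported_program,
      fx_infra.decomp.pre_convert_decomp(),
  )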

ai_edge_torch/fx_infra/graph_utils.py

Lines changed: 32 additions & 0 deletions
@@ -13,7 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 """FX graph utilities."""
+
+from packaging import version
 import torch
+from torch.fx import traceback
 
 
 def remove_dangling_args(graph_module: torch.fx.GraphModule):
@@ -40,3 +43,32 @@ def remove_assert_tensor_metadata_nodes(graph_module: torch.fx.GraphModule):
   graph_module.graph.lint()
   graph_module.recompile()
   return graph_module
+
+
+def is_torch_version_under(torch_version: str) -> bool:
+  """Checks if the current torch version is under the given version."""
+  if not torch_version:
+    raise ValueError("torch_version cannot be empty.")
+  current_version = version.parse(torch.__version__)
+  compared_version = version.parse(torch_version)
+  return current_version < compared_version
+
+
+def reset_from_node_meta(ep: torch.export.ExportedProgram):
+  """Resets each node's "from_node" meta field to the fx node name only."""
+
+  for node in ep.graph.nodes:
+    if not hasattr(node, "meta") or "from_node" not in node.meta:
+      continue
+    if is_torch_version_under("2.6.0.dev0"):
+      # For torch versions under 2.6.0, the history stack is a list of
+      # tuples; keep only the current node's name in it.
+      history = [(node.name,)]
+    else:
+      # Clean up the history stack by keeping only the current node info (fx
+      # node name and graph id) in a list of size 1. Clear the "from_node"
+      # field to prevent redundant additions to the history stack.
+      history = [traceback.NodeSource(node)]
+      history[0].from_node = []
+    node.meta["from_node"] = history
+  return ep
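
One subtlety worth noting: the cutoff is "2.6.0.dev0", not "2.6.0", because PEP 440 (as implemented by packaging) orders dev releases before the corresponding final release, so torch 2.6 nightlies already take the NodeSource branch. A quick standalone illustration of the comparison semantics:

from packaging import version

# Dev releases sort before their final release under PEP 440, so every
# 2.5.x build is "under 2.6.0.dev0" while 2.6 nightlies and 2.6.0 are not.
assert version.parse("2.5.1") < version.parse("2.6.0.dev0")
assert not version.parse("2.6.0.dev20250101") < version.parse("2.6.0.dev0")
assert not version.parse("2.6.0") < version.parse("2.6.0.dev0")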

ai_edge_torch/odml_torch/debuginfo/_build.py

Lines changed: 11 additions & 9 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import re
+from ai_edge_torch.fx_infra import graph_utils
 import torch
 
 
@@ -57,9 +57,13 @@ def _get_canoical_nodename(node: torch.fx.Node) -> str:
   """Get the canonical node name from the node's history."""
 
   history = node.meta.get("from_node", [])
+  if not history:
+    return None
 
-  if len(history) > 1:  # Compatible with torch version under 2.6.0
-    return history[1][0]
+  # Compatible with torch versions under 2.6.0: the history stack is a list
+  # of tuples, and the first element of the first tuple is the node name.
+  if graph_utils.is_torch_version_under("2.6.0.dev0"):
+    return history[0][0]
 
   if not hasattr(history[0], "name"):
     return None
@@ -68,12 +72,10 @@ def _get_canoical_nodename(node: torch.fx.Node) -> str:
     names.append(history[0].name)
     history = history[0].from_node
 
-  # Based on the experiment, the third to last name in the history stack
-  # can be mapped to the original torch node name. The history stack is
-  # generated by tracing the node's transformation history during lowering.
-  if len(names) > 2:
-    return names[-3]
-  return None
+  # The history stack is generated by tracing the node's transformation
+  # history during lowering. The last name in the history stack maps to
+  # the original torch fx node name.
+  return names[-1]
 
 
 def build_mlir_file_debuginfo(node: torch.fx.Node):
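
To make the walk concrete: on torch >= 2.6, each "from_node" entry is a NodeSource that carries a .name and links to its predecessors through .from_node, so the loop above collects names from newest to oldest. After reset_from_node_meta the chain is a single entry, which is why names[-1] is the fx node name. A hypothetical standalone mirror of that loop (the helper name is illustrative, not from the commit):

def _collect_history_names(node) -> list[str]:
  """Collects NodeSource names along the "from_node" chain, newest first."""
  names = []
  history = node.meta.get("from_node", [])
  while history and hasattr(history[0], "name"):
    names.append(history[0].name)
    history = history[0].from_node
  # After reset_from_node_meta the chain has length 1, so names[-1] (when
  # names is non-empty) is the original torch fx node name.
  return names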

ai_edge_torch/odml_torch/test/test_tf_integration.py

Lines changed: 46 additions & 9 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from ai_edge_torch import fx_infra
 from ai_edge_torch import odml_torch
 import numpy as np
 import torch
@@ -20,6 +21,23 @@
 from absl.testing import absltest as googletest
 
 
+def _reset_from_node_meta_and_lower(ep: torch.export.ExportedProgram):
+  """Lowers the exported program with a canonical history stack."""
+  ep = fx_infra.graph_utils.reset_from_node_meta(ep)
+  return odml_torch.export.exported_program_to_mlir(ep)
+
+
+def _is_aten_op(node: torch.fx.Node) -> bool:
+  return node.op == "call_function" and not node.name.startswith("getitem")
+
+
+class AddModel(torch.nn.Module):
+  """A simple model that does addition."""
+
+  def forward(self, x, y):
+    return x + y + x + y
+
+
 class TensorflowIntegrationTest(googletest.TestCase):
 
   def setUp(self):
@@ -29,11 +47,6 @@ def setUp(self):
   def test_mlir_lowered_call(self):
     """Test a simple model with MLIR lowered call."""
 
-    class AddModel(torch.nn.Module):
-
-      def forward(self, x, y):
-        return x + y + x + y
-
     model = AddModel().eval()
     forward_args = lambda: (torch.rand((10, 10)), torch.rand((10, 10)))
     ep = torch.export.export(model, forward_args())
@@ -53,8 +66,7 @@ def test_resnet18(self):
     forward_args = lambda: (torch.rand((1, 3, 224, 224)),)
 
     ep = torch.export.export(model, forward_args())
-
-    lowered = odml_torch.export.exported_program_to_mlir(ep)
+    lowered = _reset_from_node_meta_and_lower(ep)
 
     args = forward_args()
     torch_output = model(*args).detach().numpy()
@@ -70,7 +82,7 @@ def test_debuginfo_from_export_lower(self):
     forward_args = lambda: (torch.rand((1, 3, 224, 224)),)
 
     ep = torch.export.export(model, forward_args())
-    lowered = odml_torch.export.exported_program_to_mlir(ep)
+    lowered = _reset_from_node_meta_and_lower(ep)
 
     lowered_text = lowered.get_text(enable_debug_info=True)
     # Check the file info.
@@ -79,10 +91,35 @@ def test_debuginfo_from_export_lower(self):
     for n in ep.graph.nodes:
       # Record all aten op nodes from the original graph and check if they
       # are lowered to the same name in the lowered graph.
-      if n.op == "call_function" and not n.name.startswith("getitem"):
+      if _is_aten_op(n):
         # Ensure strings like `loc("relu__1"` are present in the lowered text.
         self.assertIn(f'loc("{n.name}"', lowered_text)
 
+  def test_debuginfo_from_loaded_reexport_lower(self):
+    """Tests that debuginfo survives save, load, and reexport before lowering."""
+
+    model = AddModel().eval()
+    forward_args = lambda: (torch.rand((10, 10)), torch.rand((10, 10)))
+
+    # Ensure the debuginfo is preserved after saving, loading, and reexporting.
+    ep = torch.export.export(model, forward_args())
+    torch.export.save(ep, "/tmp/add_model.pt2")
+    loaded_ep = torch.export.load("/tmp/add_model.pt2")
+    reexported_ep = torch.export.export(loaded_ep.module(), forward_args())
+    lowered = _reset_from_node_meta_and_lower(reexported_ep)
+
+    lowered_text = lowered.get_text(enable_debug_info=True)
+    # Check the file info.
+    self.assertIn(
+        "ai_edge_torch/odml_torch/test/test_tf_integration.py", lowered_text
+    )
+    # Check the fx node names.
+    for n in reexported_ep.graph.nodes:
+      # Record all aten op nodes from the reexported graph and check if they
+      # are lowered to the same name in the lowered graph.
+      if _is_aten_op(n):
+        self.assertIn(f'loc("{n.name}"', lowered_text)
+
 
 if __name__ == "__main__":
   googletest.main()
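
Not part of the commit, but a hypothetical way to spot-check the invariant the new test relies on: after reset_from_node_meta, on torch >= 2.6 every node that still carries "from_node" holds a one-element history naming the node itself (attribute access as in _build.py above).

import torch

from ai_edge_torch import fx_infra

# Hypothetical inspection snippet; AddModel as defined in the test file above.
model = AddModel().eval()
ep = torch.export.export(model, (torch.rand((10, 10)), torch.rand((10, 10))))
ep = fx_infra.graph_utils.reset_from_node_meta(ep)
for node in ep.graph.nodes:
  history = node.meta.get("from_node", [])
  if history:
    # A single canonical entry whose name matches the node itself.
    assert len(history) == 1 and history[0].name == node.name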
