Skip to content

Commit 7af2c7a

Browse files
yijie-yangcopybara-github
authored andcommitted
Correct fx node name in ODML torch lowering
PiperOrigin-RevId: 729584237
1 parent 2f96332 commit 7af2c7a

File tree

2 files changed

+51
-12
lines changed

2 files changed

+51
-12
lines changed

ai_edge_torch/odml_torch/debuginfo/_build.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
import torch
1615
import re
16+
import torch
1717

1818

1919
def _class_fullname(cls):
@@ -35,7 +35,7 @@ def _get_hierarchy(node: torch.fx.Node):
3535
return hierachy_str
3636

3737

38-
def _get_canonical_filename(filename):
38+
def _get_canonical_filename(filename: str):
3939
"""Remove unnecessary path prefix to make the filename more readable.
4040
4141
This should be factored out so that pattern is a global option that a user
@@ -53,6 +53,29 @@ def _get_canonical_filename(filename):
5353
return filename
5454

5555

56+
def _get_canoical_nodename(node: torch.fx.Node) -> str:
57+
"""Get the canonical node name from the node's history."""
58+
59+
history = node.meta.get("from_node", [])
60+
61+
if len(history) > 1: # Compatible with torch version under 2.6.0
62+
return history[1][0]
63+
64+
if not hasattr(history[0], "name"):
65+
return None
66+
names = []
67+
while history:
68+
names.append(history[0].name)
69+
history = history[0].from_node
70+
71+
# Based on the experiment, the third to last name in the history stack
72+
# can be mapped to the original torch node name. The history stack is
73+
# generated by tracing the node's transformation history during lowering.
74+
if len(names) > 2:
75+
return names[-3]
76+
return None
77+
78+
5679
def build_mlir_file_debuginfo(node: torch.fx.Node):
5780
"""Build the file and line info for the given node's lowerings in MLIR."""
5881

@@ -66,16 +89,13 @@ def build_mlir_file_debuginfo(node: torch.fx.Node):
6689
return _get_canonical_filename(pt_trace.file), int(pt_trace.lineno)
6790

6891

69-
def build_nodename_debuginfo(node: torch.fx.Node):
92+
def build_nodename_debuginfo(node: torch.fx.Node) -> str:
7093
"""Build the fx node name for the given node's lowerings in MLIR."""
71-
history = node.meta.get("from_node", [])
72-
if not history:
94+
95+
if not hasattr(node, "meta") or "from_node" not in node.meta:
7396
return None
74-
if len(history) > 1:
75-
return history[1][0]
76-
if hasattr(history[0], "name"): # torch 2.6.0+
77-
return history[0].name
78-
return None
97+
98+
return _get_canoical_nodename(node)
7999

80100

81101
def build_mlir_debuginfo(node: torch.fx.Node):

ai_edge_torch/odml_torch/test/test_tf_integration.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def setUp(self):
2727
torch.manual_seed(0)
2828

2929
def test_mlir_lowered_call(self):
30+
"""Test a simple model with MLIR lowered call."""
31+
3032
class AddModel(torch.nn.Module):
3133

3234
def forward(self, x, y):
@@ -45,6 +47,8 @@ def forward(self, x, y):
4547
self.assertTrue(np.allclose(lowering_output, torch_output))
4648

4749
def test_resnet18(self):
50+
"""Test Resnet18 model with MLIR lowered call."""
51+
4852
model = torchvision.models.resnet18().eval()
4953
forward_args = lambda: (torch.rand((1, 3, 224, 224)),)
5054

@@ -58,11 +62,26 @@ def test_resnet18(self):
5862

5963
# Check value and debug info.
6064
self.assertTrue(np.allclose(lowering_output, torch_output, atol=1e-5))
65+
66+
def test_debuginfo_from_export_lower(self):
67+
"""Test the debuginfo with export lower."""
68+
69+
model = torchvision.models.resnet18().eval()
70+
forward_args = lambda: (torch.rand((1, 3, 224, 224)),)
71+
72+
ep = torch.export.export(model, forward_args())
73+
lowered = odml_torch.export.exported_program_to_mlir(ep)
74+
6175
lowered_text = lowered.get_text(enable_debug_info=True)
6276
# Check the file info.
6377
self.assertIn("torchvision/models/resnet.py", lowered_text)
64-
# Check the fx node name.
65-
self.assertIn("relu_1", lowered_text)
78+
# Check the fx node names.
79+
for n in ep.graph.nodes:
80+
# Record all aten op nodes from the original graph and check if they
81+
# are lowered to the same name in the lowered graph.
82+
if n.op == "call_function" and not n.name.startswith("getitem"):
83+
# Ensure strings like `loc("relu__1"` are present in the lowered text.
84+
self.assertIn(f'loc("{n.name}"', lowered_text)
6685

6786

6887
if __name__ == "__main__":

0 commit comments

Comments
 (0)