Skip to content

Commit 53390dd

Browse files
yijie-yangcopybara-github
authored andcommitted
Add fx node name info to ODML Torch lowerings.
PiperOrigin-RevId: 725602099
1 parent 48d225a commit 53390dd

File tree

4 files changed

+55
-7
lines changed

4 files changed

+55
-7
lines changed

ai_edge_torch/odml_torch/debuginfo/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
from ._build import build_mlir_debuginfo, build_mlir_file_debuginfo
15+
"""Debug info generation for ODML Torch."""
16+
17+
from . import _build
1618
from ._op_polyfill import write_mlir_debuginfo_op
19+
20+
build_nodename_debuginfo = _build.build_nodename_debuginfo
21+
build_mlir_file_debuginfo = _build.build_mlir_file_debuginfo
22+
build_mlir_debuginfo = _build.build_mlir_debuginfo

ai_edge_torch/odml_torch/debuginfo/_build.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,16 @@ def _get_canonical_filename(filename):
4040
4141
This should be factored out so that pattern is a global option that a user
4242
can override.
43+
44+
Args:
45+
filename: The filename to canonicalize.
46+
47+
Returns:
48+
The canonicalized filename.
4349
"""
4450

45-
# TODO: We should add a config option to provide a regex to strip from the
46-
# debug info. Currently absolute path is used.
51+
# TODO(yijieyang): We should add a config option to provide a regex to strip
52+
# from the debug info. Currently absolute path is used.
4753
return filename
4854

4955

@@ -55,9 +61,23 @@ def build_mlir_file_debuginfo(node: torch.fx.Node):
5561

5662
# Note: This uses internal APIs and may break in the future.
5763
pt_trace = torch.fx.graph._parse_stack_trace(node.stack_trace)
64+
if pt_trace is None:
65+
return None, None
5866
return _get_canonical_filename(pt_trace.file), int(pt_trace.lineno)
5967

6068

69+
def build_nodename_debuginfo(node: torch.fx.Node):
70+
"""Build the fx node name for the given node's lowerings in MLIR."""
71+
history = node.meta.get("from_node", [])
72+
if not history:
73+
return None
74+
if len(history) > 1:
75+
return history[1][0]
76+
if hasattr(history[0], "name"): # torch 2.6.0+
77+
return history[0].name
78+
return None
79+
80+
6181
def build_mlir_debuginfo(node: torch.fx.Node):
6282
"""Build the debuginfo string for the given node's lowerings in MLIR."""
6383

ai_edge_torch/odml_torch/export.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,21 @@ def __init__(self, module: torch.fx.GraphModule, lctx: LoweringContext):
8888
self.outputs = None
8989

9090
def _build_loc(self, node: torch.fx.Node):
91+
"""Build MLIR location for the given node.
92+
93+
The location contains:
94+
- layer info
95+
- fx node name
96+
- file and line info
97+
98+
Currently it's still under development and format is subject to change.
99+
100+
Args:
101+
node: The torch.fx.Node to build the location for.
102+
103+
Returns:
104+
The MLIR location for the given node.
105+
"""
91106

92107
info = debuginfo.build_mlir_debuginfo(node)
93108
if info is None:
@@ -98,7 +113,12 @@ def _build_loc(self, node: torch.fx.Node):
98113
if file is not None:
99114
fileinfo = ir.Location.file(filename=file, line=line, col=0)
100115

101-
return ir.Location.name(name=info, childLoc=fileinfo)
116+
node_name = debuginfo.build_nodename_debuginfo(node)
117+
nodeinfo = None
118+
if node_name is not None:
119+
nodeinfo = ir.Location.name(name=node_name, childLoc=fileinfo)
120+
121+
return ir.Location.name(name=info, childLoc=nodeinfo)
102122

103123
def run_node(self, node: torch.fx.Node):
104124
loc = self._build_loc(node)

ai_edge_torch/odml_torch/test/test_tf_integration.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,11 @@ def test_resnet18(self):
5858

5959
# Check value and debug info.
6060
self.assertTrue(np.allclose(lowering_output, torch_output, atol=1e-5))
61-
self.assertIn(
62-
"torchvision/models/resnet.py", lowered.get_text(enable_debug_info=True)
63-
)
61+
lowered_text = lowered.get_text(enable_debug_info=True)
62+
# Check the file info.
63+
self.assertIn("torchvision/models/resnet.py", lowered_text)
64+
# Check the fx node name.
65+
self.assertIn("relu_1", lowered_text)
6466

6567

6668
if __name__ == "__main__":

0 commit comments

Comments
 (0)