Skip to content

Commit 729690b

Browse files
GleasonKcopybara-github
authored andcommitted
Add file-line info to ODML Torch lowerings.
PiperOrigin-RevId: 724070541
1 parent 45f18d2 commit 729690b

File tree

4 files changed

+35
-2
lines changed

4 files changed

+35
-2
lines changed

ai_edge_torch/odml_torch/debuginfo/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
from ._build import build_mlir_debuginfo
15+
from ._build import build_mlir_debuginfo, build_mlir_file_debuginfo
1616
from ._op_polyfill import write_mlir_debuginfo_op

ai_edge_torch/odml_torch/debuginfo/_build.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
import torch
16+
import re
1617

1718

1819
def _class_fullname(cls):
@@ -34,6 +35,29 @@ def _get_hierarchy(node: torch.fx.Node):
3435
return hierachy_str
3536

3637

38+
def _get_canonical_filename(filename):
39+
"""Remove unnecessary path prefix to make the filename more readable.
40+
41+
This should be factored out so that pattern is a global option that a user
42+
can override.
43+
"""
44+
45+
# TODO: We should add a config option to provide a regex to strip from the
46+
# debug info. Currently absolute path is used.
47+
return filename
48+
49+
50+
def build_mlir_file_debuginfo(node: torch.fx.Node):
51+
"""Build the file and line info for the given node's lowerings in MLIR."""
52+
53+
if not node.stack_trace:
54+
return None, None
55+
56+
# Note: This uses internal APIs and may break in the future.
57+
pt_trace = torch.fx.graph._parse_stack_trace(node.stack_trace)
58+
return _get_canonical_filename(pt_trace.file), int(pt_trace.lineno)
59+
60+
3761
def build_mlir_debuginfo(node: torch.fx.Node):
3862
"""Build the debuginfo string for the given node's lowerings in MLIR."""
3963

ai_edge_torch/odml_torch/export.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,12 @@ def _build_loc(self, node: torch.fx.Node):
9393
if info is None:
9494
return ir.Location.unknown()
9595

96-
return ir.Location.name(name=info)
96+
(file, line) = debuginfo.build_mlir_file_debuginfo(node)
97+
fileinfo = None
98+
if file is not None:
99+
fileinfo = ir.Location.file(filename=file, line=line, col=0)
100+
101+
return ir.Location.name(name=info, childLoc=fileinfo)
97102

98103
def run_node(self, node: torch.fx.Node):
99104
loc = self._build_loc(node)

ai_edge_torch/odml_torch/test/test_tf_integration.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ def test_resnet18(self):
5656
torch_output = model(*args).detach().numpy()
5757
lowering_output = np.array(lowered(*args))
5858

59+
# Check value and debug info.
5960
self.assertTrue(np.allclose(lowering_output, torch_output, atol=1e-5))
61+
self.assertIn(
62+
"torchvision/models/resnet.py", lowered.get_text(enable_debug_info=True)
63+
)
6064

6165

6266
if __name__ == "__main__":

0 commit comments

Comments
 (0)