File tree Expand file tree Collapse file tree 4 files changed +35
-2
lines changed Expand file tree Collapse file tree 4 files changed +35
-2
lines changed Original file line number Diff line number Diff line change 1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414# ==============================================================================
15- from ._build import build_mlir_debuginfo
15+ from ._build import build_mlir_debuginfo , build_mlir_file_debuginfo
1616from ._op_polyfill import write_mlir_debuginfo_op
Original file line number Diff line number Diff line change 1313# limitations under the License.
1414# ==============================================================================
1515import torch
16+ import re
1617
1718
1819def _class_fullname (cls ):
@@ -34,6 +35,29 @@ def _get_hierarchy(node: torch.fx.Node):
3435 return hierachy_str
3536
3637
38+ def _get_canonical_filename (filename ):
39+ """Remove unnecessary path prefix to make the filename more readable.
40+
41+ This should be factored out so that pattern is a global option that a user
42+ can override.
43+ """
44+
45+ # TODO: We should add a config option to provide a regex to strip from the
46+ # debug info. Currently absolute path is used.
47+ return filename
48+
49+
50+ def build_mlir_file_debuginfo (node : torch .fx .Node ):
51+ """Build the file and line info for the given node's lowerings in MLIR."""
52+
53+ if not node .stack_trace :
54+ return None , None
55+
56+ # Note: This uses internal APIs and may break in the future.
57+ pt_trace = torch .fx .graph ._parse_stack_trace (node .stack_trace )
58+ return _get_canonical_filename (pt_trace .file ), int (pt_trace .lineno )
59+
60+
3761def build_mlir_debuginfo (node : torch .fx .Node ):
3862 """Build the debuginfo string for the given node's lowerings in MLIR."""
3963
Original file line number Diff line number Diff line change @@ -93,7 +93,12 @@ def _build_loc(self, node: torch.fx.Node):
9393 if info is None :
9494 return ir .Location .unknown ()
9595
96- return ir .Location .name (name = info )
96+ (file , line ) = debuginfo .build_mlir_file_debuginfo (node )
97+ fileinfo = None
98+ if file is not None :
99+ fileinfo = ir .Location .file (filename = file , line = line , col = 0 )
100+
101+ return ir .Location .name (name = info , childLoc = fileinfo )
97102
98103 def run_node (self , node : torch .fx .Node ):
99104 loc = self ._build_loc (node )
Original file line number Diff line number Diff line change @@ -56,7 +56,11 @@ def test_resnet18(self):
5656 torch_output = model (* args ).detach ().numpy ()
5757 lowering_output = np .array (lowered (* args ))
5858
59+ # Check value and debug info.
5960 self .assertTrue (np .allclose (lowering_output , torch_output , atol = 1e-5 ))
61+ self .assertIn (
62+ "torchvision/models/resnet.py" , lowered .get_text (enable_debug_info = True )
63+ )
6064
6165
6266if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments