Skip to content

Commit 7a5a911

Browse files
committed
refactor
1 parent 02f11a9 commit 7a5a911

File tree

4 files changed

+200
-22
lines changed

4 files changed

+200
-22
lines changed

mlir/lib/Bindings/Python/Traceback.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -440,22 +440,22 @@ void BuildTracebackSubmodule(nb::module_ &m) {
440440
},
441441
"Python wrapper around the Python C API function PyCode_Addr2Line");
442442

443+
#if PY_VERSION_HEX >= 0x030b00f0
443444
type.attr("code_addr2location") = nb::cpp_function(
444445
[](nb::handle code, int lasti) {
445446
if (!PyCode_Check(code.ptr())) {
446447
throw std::runtime_error("code argument must be a code object");
447448
}
448449
int start_line, start_column, end_line, end_column;
449-
// if (!PyCode_Addr2Location(reinterpret_cast<PyCodeObject
450-
// *>(code.ptr()),
451-
// lasti, &start_line, &start_column,
452-
// &end_line, &end_column)) {
453-
// throw nb::python_error();
454-
// }
455-
throw nb::python_error();
450+
if (!PyCode_Addr2Location(reinterpret_cast<PyCodeObject *>(code.ptr()),
451+
lasti, &start_line, &start_column, &end_line,
452+
&end_column)) {
453+
throw nb::python_error();
454+
}
456455
return nb::make_tuple(start_line, start_column, end_line, end_column);
457456
},
458457
"Python wrapper around the Python C API function PyCode_Addr2Location");
458+
#endif
459459
}
460460
} // namespace mlir::python
461461

mlir/python/mlir/source_info_util.py

Lines changed: 132 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import sys
1718
from collections.abc import Iterator
1819
import contextlib
1920
import dataclasses
@@ -25,24 +26,30 @@
2526
import threading
2627
import types
2728
from typing import NamedTuple
29+
from .ir import Location
2830

2931
# import jax.version
3032
# from jax._src.lib import xla_client
3133

3234
from . import traceback_util
35+
from .traceback_util import (
36+
TracebackCaches,
37+
Traceback,
38+
_traceback_caches,
39+
_traceback_in_locations_limit,
40+
_include_full_tracebacks_in_locations,
41+
)
3342

3443
traceback_util.register_exclusion(__file__)
3544

36-
from ._mlir_libs._mlir import Traceback
37-
3845

3946
class Frame(NamedTuple):
4047
file_name: str
4148
function_name: str
4249
start_line: int
43-
start_column: int
44-
end_line: int
45-
end_column: int
50+
start_column: int | None
51+
end_line: int | None
52+
end_column: int | None
4653

4754

4855
_exclude_paths: list[str] = [
@@ -173,16 +180,29 @@ def is_user_filename(filename: str) -> bool:
173180

174181

175182
def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame:
176-
loc = Traceback.code_addr2location(code, lasti)
177-
start_line, start_column, end_line, end_column = loc
178-
return Frame(
179-
file_name=code.co_filename,
180-
function_name=code.co_qualname,
181-
start_line=start_line,
182-
start_column=start_column,
183-
end_line=end_line,
184-
end_column=end_column,
185-
)
183+
if sys.version_info.minor >= 11:
184+
loc = Traceback.code_addr2location(code, lasti)
185+
start_line, start_column, end_line, end_column = loc
186+
frame = Frame(
187+
file_name=code.co_filename,
188+
function_name=code.co_qualname,
189+
start_line=start_line,
190+
start_column=start_column,
191+
end_line=end_line,
192+
end_column=end_column,
193+
)
194+
else:
195+
start_line = Traceback.code_addr2line(code, lasti)
196+
frame = Frame(
197+
file_name=code.co_filename,
198+
function_name=code.co_qualname,
199+
start_line=start_line,
200+
start_column=None,
201+
end_line=None,
202+
end_column=None,
203+
)
204+
205+
return frame
186206

187207

188208
def user_frames(traceback: Traceback | None) -> Iterator[Frame]:
@@ -365,3 +385,100 @@ def __exit__(self, exc_type, exc_value, traceback):
365385

366386

367387
transform_name_stack = TransformNameStackContextManager
388+
389+
390+
def get_canonical_source_file(file_name: str, caches: TracebackCaches) -> str:
391+
canonical_file_name = caches.canonical_name_cache.get(file_name, None)
392+
if canonical_file_name is not None:
393+
return canonical_file_name
394+
395+
# pattern = config.hlo_source_file_canonicalization_regex.value
396+
# if pattern:
397+
# file_name = re.sub(pattern, "", file_name)
398+
caches.canonical_name_cache[file_name] = file_name
399+
return file_name
400+
401+
402+
def _is_user_file(file_name: str) -> bool:
403+
is_user = _traceback_caches.is_user_file_cache.get(file_name, None)
404+
if is_user is not None:
405+
return is_user
406+
out = is_user_filename(file_name)
407+
_traceback_caches.is_user_file_cache[file_name] = out
408+
return out
409+
410+
411+
def _traceback_to_location(tb: Traceback) -> Location:
412+
"""Converts a full traceback to a callsite() MLIR location."""
413+
loc = _traceback_caches.traceback_cache.get(tb, None)
414+
if loc is not None:
415+
return loc
416+
417+
frame_locs = []
418+
frames_limit = _traceback_in_locations_limit
419+
frames_limit = frames_limit if frames_limit >= 0 else 1000
420+
421+
codes, lastis = tb.raw_frames()
422+
for i, code in enumerate(codes):
423+
if not _is_user_file(code.co_filename):
424+
continue
425+
426+
lasti = lastis[i]
427+
code_lasti = code, lasti
428+
loc = _traceback_caches.location_cache.get(code_lasti, None)
429+
if loc is None:
430+
frame = raw_frame_to_frame(code, lasti)
431+
file_loc = Location.file(
432+
get_canonical_source_file(frame.file_name, _traceback_caches),
433+
frame.start_line,
434+
frame.start_column,
435+
frame.end_line,
436+
frame.end_column,
437+
)
438+
loc = Location.name(frame.function_name, childLoc=file_loc)
439+
_traceback_caches.location_cache[code_lasti] = loc
440+
frame_locs.append(loc)
441+
if len(frame_locs) >= frames_limit:
442+
break
443+
444+
n = len(frame_locs)
445+
if n == 0:
446+
loc = Location.unknown()
447+
elif n == 1:
448+
loc = frame_locs[0]
449+
else:
450+
loc = Location.callsite(frame_locs[0], frame_locs[1:])
451+
_traceback_caches.traceback_cache[tb] = loc
452+
return loc
453+
454+
455+
def source_info_to_location(
456+
primitive: None,
457+
name_stack: NameStack,
458+
traceback: Traceback | None,
459+
) -> Location:
460+
if _include_full_tracebacks_in_locations:
461+
if traceback is None:
462+
loc = Location.unknown()
463+
else:
464+
loc = _traceback_to_location(traceback)
465+
else:
466+
frame = user_frame(traceback)
467+
if frame is None:
468+
loc = Location.unknown()
469+
else:
470+
loc = Location.file(
471+
get_canonical_source_file(frame.file_name, _traceback_caches),
472+
frame.start_line,
473+
frame.start_column,
474+
)
475+
if primitive is None:
476+
if name_stack.stack:
477+
loc = Location.name(str(name_stack), childLoc=loc)
478+
else:
479+
eqn_str = (
480+
f"{name_stack}/{primitive.name}" if name_stack.stack else primitive.name
481+
)
482+
loc = Location.name(eqn_str, childLoc=loc)
483+
loc = Location.name(f"{primitive.name}:", childLoc=loc)
484+
return loc

mlir/python/mlir/traceback_util.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@
1414

1515
from __future__ import annotations
1616

17+
import dataclasses
1718
from collections.abc import Callable
1819
import functools
1920
import os
2021
import traceback
2122
import types
2223
from typing import Any, TypeVar, cast
24+
from ._mlir_libs._mlir import Traceback
25+
from .ir import Location
2326

2427

2528
C = TypeVar("C", bound=Callable[..., Any])
@@ -236,3 +239,22 @@ def reraise_with_filtered_traceback(*args, **kwargs):
236239
del mode
237240

238241
return cast(C, reraise_with_filtered_traceback)
242+
243+
244+
@dataclasses.dataclass
245+
class TracebackCaches:
246+
traceback_cache: dict[Traceback, Location]
247+
location_cache: dict[tuple[types.CodeType, int], Location]
248+
canonical_name_cache: dict[str, str]
249+
is_user_file_cache: dict[str, bool]
250+
251+
def __init__(self):
252+
self.traceback_cache = {}
253+
self.location_cache = {}
254+
self.canonical_name_cache = {}
255+
self.is_user_file_cache = {}
256+
257+
258+
_traceback_caches = TracebackCaches()
259+
_traceback_in_locations_limit = 100
260+
_include_full_tracebacks_in_locations = True

mlir/test/python/ir/line_info.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
import gc
3+
import traceback
4+
5+
from mlir import source_info_util
6+
from mlir.source_info_util import _traceback_to_location
7+
from mlir import traceback_util
8+
from mlir.ir import Context
9+
10+
# CHECK: hello
11+
print("hello")
12+
13+
14+
# traceback_util.register_exclusion(__file__)
15+
16+
17+
def run(f):
18+
print("\nTEST:", f.__name__)
19+
with Context() as ctx:
20+
f()
21+
gc.collect()
22+
# assert Context._get_live_count() == 0
23+
return f
24+
25+
26+
@run
27+
def foo():
28+
def bar():
29+
curr = source_info_util.current()
30+
print(curr.name_stack)
31+
print(curr.traceback)
32+
traceback.print_tb(
33+
traceback_util.filter_traceback(curr.traceback.as_python_traceback())
34+
)
35+
36+
loc = _traceback_to_location(curr.traceback)
37+
print(loc)
38+
39+
bar()

0 commit comments

Comments
 (0)