Skip to content

Commit f10e124

Browse files
committed
refactor
1 parent 02f11a9 commit f10e124

File tree

4 files changed

+212
-24
lines changed

4 files changed

+212
-24
lines changed

mlir/lib/Bindings/Python/Traceback.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -440,22 +440,22 @@ void BuildTracebackSubmodule(nb::module_ &m) {
440440
},
441441
"Python wrapper around the Python C API function PyCode_Addr2Line");
442442

443+
#if PY_VERSION_HEX >= 0x030b00f0
443444
type.attr("code_addr2location") = nb::cpp_function(
444445
[](nb::handle code, int lasti) {
445446
if (!PyCode_Check(code.ptr())) {
446447
throw std::runtime_error("code argument must be a code object");
447448
}
448449
int start_line, start_column, end_line, end_column;
449-
// if (!PyCode_Addr2Location(reinterpret_cast<PyCodeObject
450-
// *>(code.ptr()),
451-
// lasti, &start_line, &start_column,
452-
// &end_line, &end_column)) {
453-
// throw nb::python_error();
454-
// }
455-
throw nb::python_error();
450+
if (!PyCode_Addr2Location(reinterpret_cast<PyCodeObject *>(code.ptr()),
451+
lasti, &start_line, &start_column, &end_line,
452+
&end_column)) {
453+
throw nb::python_error();
454+
}
456455
return nb::make_tuple(start_line, start_column, end_line, end_column);
457456
},
458457
"Python wrapper around the Python C API function PyCode_Addr2Location");
458+
#endif
459459
}
460460
} // namespace mlir::python
461461

mlir/python/mlir/source_info_util.py

Lines changed: 144 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import sys
1718
from collections.abc import Iterator
1819
import contextlib
1920
import dataclasses
@@ -24,25 +25,32 @@
2425
import sysconfig
2526
import threading
2627
import types
27-
from typing import NamedTuple
28+
from typing import NamedTuple, Optional
29+
from .ir import Location
2830

2931
# import jax.version
3032
# from jax._src.lib import xla_client
3133

3234
from . import traceback_util
35+
from .traceback_util import (
36+
TracebackCaches,
37+
Traceback,
38+
_traceback_caches,
39+
_traceback_in_locations_limit,
40+
_include_full_tracebacks_in_locations,
41+
)
3342

3443
traceback_util.register_exclusion(__file__)
3544

36-
from ._mlir_libs._mlir import Traceback
37-
3845

39-
class Frame(NamedTuple):
46+
@dataclasses.dataclass(frozen=True)
class Frame:
    """Immutable description of where one Python stack frame points in source.

    Only the file, function and starting line are required. The column and
    end-line fields default to ``None`` so a frame can be built when just a
    line number is known.
    """

    # Path of the source file, as recorded on the code object.
    file_name: str
    # Name of the function the frame is executing.
    function_name: str
    # Line on which the current instruction starts.
    start_line: int
    # Remaining parts of the source range; None when not provided.
    start_column: int | None = None
    end_line: int | None = None
    end_column: int | None = None
4654

4755

4856
_exclude_paths: list[str] = [
@@ -173,16 +181,27 @@ def is_user_filename(filename: str) -> bool:
173181

174182

175183
def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame:
    """Build a `Frame` for the instruction at offset `lasti` within `code`.

    On Python 3.11 and newer the full source range (start/end line and
    column) is taken from `Traceback.code_addr2location` and the function is
    reported by `co_qualname`. On older interpreters only the line number is
    looked up (via `Traceback.code_addr2line`); the function is reported by
    `co_name`, the start column is set to 0 and the end position is left
    unset.

    Args:
        code: The code object the frame is executing.
        lasti: Instruction offset within `code`.

    Returns:
        A `Frame` describing the source position.
    """
    # Compare the (major, minor) pair rather than `.minor` alone: a bare
    # minor-version check gives the wrong answer for any major version
    # other than 3.
    if sys.version_info >= (3, 11):
        start_line, start_column, end_line, end_column = (
            Traceback.code_addr2location(code, lasti)
        )
        return Frame(
            file_name=code.co_filename,
            function_name=code.co_qualname,
            start_line=start_line,
            start_column=start_column,
            end_line=end_line,
            end_column=end_column,
        )

    # Pre-3.11 fallback: line-level information only.
    return Frame(
        file_name=code.co_filename,
        function_name=code.co_name,
        start_line=Traceback.code_addr2line(code, lasti),
        start_column=0,
    )
186205

187206

188207
def user_frames(traceback: Traceback | None) -> Iterator[Frame]:
@@ -365,3 +384,111 @@ def __exit__(self, exc_type, exc_value, traceback):
365384

366385

367386
transform_name_stack = TransformNameStackContextManager
387+
388+
389+
def get_canonical_source_file(file_name: str, caches: TracebackCaches) -> str:
    """Return the canonical form of `file_name`, memoised in `caches`.

    The result is looked up in `caches.canonical_name_cache` first. On a
    miss the name is stored and returned unchanged.
    """
    known_names = caches.canonical_name_cache
    cached_name = known_names.get(file_name)
    if cached_name is None:
        # Regex-based canonicalisation is currently disabled:
        # pattern = config.hlo_source_file_canonicalization_regex.value
        # if pattern:
        #     file_name = re.sub(pattern, "", file_name)
        known_names[file_name] = file_name
        return file_name
    return cached_name
399+
400+
401+
def _is_user_file(file_name: str) -> bool:
    """Memoised `is_user_filename`, backed by the module-level caches."""
    verdicts = _traceback_caches.is_user_file_cache
    verdict = verdicts.get(file_name)
    if verdict is None:
        # First time this file is seen: classify it and remember the answer.
        verdict = is_user_filename(file_name)
        verdicts[file_name] = verdict
    return verdict
408+
409+
410+
def _traceback_to_location(tb: Traceback) -> Location:
    """Converts a full traceback to a callsite() MLIR location.

    Frames from non-user files are skipped. Each remaining frame becomes a
    name location (the function name) wrapping a file location. Results are
    memoised per traceback and per (code, lasti) pair in the module-level
    caches.
    """
    caches = _traceback_caches
    cached = caches.traceback_cache.get(tb, None)
    if cached is not None:
        return cached

    # A negative limit falls back to a cap of 1000 frames.
    max_frames = _traceback_in_locations_limit
    if max_frames < 0:
        max_frames = 1000

    codes, lastis = tb.raw_frames()
    locations = []
    for index, code in enumerate(codes):
        if not _is_user_file(code.co_filename):
            continue

        lasti = lastis[index]
        key = (code, lasti)
        frame_loc = caches.location_cache.get(key, None)
        if frame_loc is None:
            frame = raw_frame_to_frame(code, lasti)
            source_file = get_canonical_source_file(frame.file_name, caches)
            has_full_range = all(
                part is not None
                for part in (frame.start_column, frame.end_line, frame.end_column)
            )
            if has_full_range:
                file_loc = Location.file(
                    source_file,
                    frame.start_line,
                    frame.start_column,
                    frame.end_line,
                    frame.end_column,
                )
            else:
                # Without a complete range, emit a line/column location only.
                file_loc = Location.file(
                    source_file,
                    frame.start_line,
                    frame.start_column,
                )
            frame_loc = Location.name(frame.function_name, childLoc=file_loc)
            caches.location_cache[key] = frame_loc

        locations.append(frame_loc)
        if len(locations) >= max_frames:
            break

    if not locations:
        result = Location.unknown()
    elif len(locations) == 1:
        result = locations[0]
    else:
        # First frame is the callee; the remaining frames form the call stack.
        result = Location.callsite(locations[0], locations[1:])

    caches.traceback_cache[tb] = result
    return result
463+
464+
465+
def source_info_to_location(
    primitive: None,
    name_stack: NameStack,
    traceback: Traceback | None,
) -> Location:
    """Build an MLIR location from a name stack and an optional traceback.

    When full tracebacks are enabled, the base location is the callsite
    location of `traceback` (or unknown if there is none). Otherwise it is a
    file location for the frame returned by `user_frame` (or unknown if that
    returns None). The base is then wrapped in name locations derived from
    `name_stack` and, when given, `primitive.name`.

    NOTE(review): `primitive` is annotated as `None`, yet the non-None path
    reads `primitive.name` - confirm the intended type with callers.
    """
    if _include_full_tracebacks_in_locations:
        base = (
            Location.unknown()
            if traceback is None
            else _traceback_to_location(traceback)
        )
    else:
        frame = user_frame(traceback)
        if frame is None:
            base = Location.unknown()
        else:
            base = Location.file(
                get_canonical_source_file(frame.file_name, _traceback_caches),
                frame.start_line,
                frame.start_column,
            )

    if primitive is not None:
        if name_stack.stack:
            eqn_str = f"{name_stack}/{primitive.name}"
        else:
            eqn_str = primitive.name
        named = Location.name(eqn_str, childLoc=base)
        return Location.name(f"{primitive.name}:", childLoc=named)

    # No primitive: only wrap when the name stack is non-empty.
    if name_stack.stack:
        return Location.name(str(name_stack), childLoc=base)
    return base

mlir/python/mlir/traceback_util.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@
1414

1515
from __future__ import annotations
1616

17+
import dataclasses
1718
from collections.abc import Callable
1819
import functools
1920
import os
2021
import traceback
2122
import types
2223
from typing import Any, TypeVar, cast
24+
from ._mlir_libs._mlir import Traceback
25+
from .ir import Location
2326

2427

2528
C = TypeVar("C", bound=Callable[..., Any])
@@ -236,3 +239,22 @@ def reraise_with_filtered_traceback(*args, **kwargs):
236239
del mode
237240

238241
return cast(C, reraise_with_filtered_traceback)
242+
243+
244+
@dataclasses.dataclass
245+
class TracebackCaches:
246+
traceback_cache: dict[Traceback, Location]
247+
location_cache: dict[tuple[types.CodeType, int], Location]
248+
canonical_name_cache: dict[str, str]
249+
is_user_file_cache: dict[str, bool]
250+
251+
def __init__(self):
252+
self.traceback_cache = {}
253+
self.location_cache = {}
254+
self.canonical_name_cache = {}
255+
self.is_user_file_cache = {}
256+
257+
258+
_traceback_caches = TracebackCaches()
259+
_traceback_in_locations_limit = 100
260+
_include_full_tracebacks_in_locations = True

mlir/test/python/ir/line_info.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
import gc
3+
import traceback
4+
5+
from mlir import source_info_util
6+
from mlir.source_info_util import _traceback_to_location
7+
from mlir import traceback_util
8+
from mlir.ir import Context
9+
10+
# CHECK: hello
11+
print("hello")
12+
13+
14+
# traceback_util.register_exclusion(__file__)
15+
16+
17+
def run(f):
    """Announce and execute `f` inside a fresh MLIR `Context`.

    Intended for use as a decorator: the test body runs immediately when it
    is decorated, and the function is returned unchanged.
    """
    print("\nTEST:", f.__name__)
    with Context():
        f()
    gc.collect()
    # Live-context check is currently disabled:
    # assert Context._get_live_count() == 0
    return f
24+
25+
26+
@run
def foo():
    """Print the current source info and the MLIR location derived from it."""

    # NOTE(review): `bar` is kept as a nested function, presumably so the
    # captured traceback contains more than one user frame - confirm.
    def bar():
        info = source_info_util.current()
        print(info.name_stack)
        print(info.traceback)
        filtered_tb = traceback_util.filter_traceback(
            info.traceback.as_python_traceback()
        )
        traceback.print_tb(filtered_tb)

        location = _traceback_to_location(info.traceback)
        print(location)

    bar()

0 commit comments

Comments
 (0)