|
14 | 14 |
|
15 | 15 | from __future__ import annotations
|
16 | 16 |
|
| 17 | +import sys |
17 | 18 | from collections.abc import Iterator
|
18 | 19 | import contextlib
|
19 | 20 | import dataclasses
|
|
25 | 26 | import threading
|
26 | 27 | import types
|
27 | 28 | from typing import NamedTuple
|
| 29 | +from .ir import Location |
28 | 30 |
|
29 | 31 | # import jax.version
|
30 | 32 | # from jax._src.lib import xla_client
|
31 | 33 |
|
32 | 34 | from . import traceback_util
|
| 35 | +from .traceback_util import ( |
| 36 | + TracebackCaches, |
| 37 | + Traceback, |
| 38 | + _traceback_caches, |
| 39 | + _traceback_in_locations_limit, |
| 40 | + _include_full_tracebacks_in_locations, |
| 41 | +) |
33 | 42 |
|
34 | 43 | traceback_util.register_exclusion(__file__)
|
35 | 44 |
|
36 |
| -from ._mlir_libs._mlir import Traceback |
37 |
| - |
38 | 45 |
|
39 | 46 | class Frame(NamedTuple):
|
40 | 47 | file_name: str
|
41 | 48 | function_name: str
|
42 | 49 | start_line: int
|
43 |
| - start_column: int |
44 |
| - end_line: int |
45 |
| - end_column: int |
| 50 | + start_column: int | None |
| 51 | + end_line: int | None |
| 52 | + end_column: int | None |
46 | 53 |
|
47 | 54 |
|
48 | 55 | _exclude_paths: list[str] = [
|
@@ -173,16 +180,29 @@ def is_user_filename(filename: str) -> bool:
|
173 | 180 |
|
174 | 181 |
|
175 | 182 | def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame:
|
176 |
| - loc = Traceback.code_addr2location(code, lasti) |
177 |
| - start_line, start_column, end_line, end_column = loc |
178 |
| - return Frame( |
179 |
| - file_name=code.co_filename, |
180 |
| - function_name=code.co_qualname, |
181 |
| - start_line=start_line, |
182 |
| - start_column=start_column, |
183 |
| - end_line=end_line, |
184 |
| - end_column=end_column, |
185 |
| - ) |
| 183 | + if sys.version_info.minor >= 11: |
| 184 | + loc = Traceback.code_addr2location(code, lasti) |
| 185 | + start_line, start_column, end_line, end_column = loc |
| 186 | + frame = Frame( |
| 187 | + file_name=code.co_filename, |
| 188 | + function_name=code.co_qualname, |
| 189 | + start_line=start_line, |
| 190 | + start_column=start_column, |
| 191 | + end_line=end_line, |
| 192 | + end_column=end_column, |
| 193 | + ) |
| 194 | + else: |
| 195 | + start_line = Traceback.code_addr2line(code, lasti) |
| 196 | + frame = Frame( |
| 197 | + file_name=code.co_filename, |
| 198 | + function_name=code.co_qualname, |
| 199 | + start_line=start_line, |
| 200 | + start_column=None, |
| 201 | + end_line=None, |
| 202 | + end_column=None, |
| 203 | + ) |
| 204 | + |
| 205 | + return frame |
186 | 206 |
|
187 | 207 |
|
188 | 208 | def user_frames(traceback: Traceback | None) -> Iterator[Frame]:
|
@@ -365,3 +385,100 @@ def __exit__(self, exc_type, exc_value, traceback):
|
365 | 385 |
|
366 | 386 |
|
367 | 387 | transform_name_stack = TransformNameStackContextManager
|
| 388 | + |
| 389 | + |
| 390 | +def get_canonical_source_file(file_name: str, caches: TracebackCaches) -> str: |
| 391 | + canonical_file_name = caches.canonical_name_cache.get(file_name, None) |
| 392 | + if canonical_file_name is not None: |
| 393 | + return canonical_file_name |
| 394 | + |
| 395 | + # pattern = config.hlo_source_file_canonicalization_regex.value |
| 396 | + # if pattern: |
| 397 | + # file_name = re.sub(pattern, "", file_name) |
| 398 | + caches.canonical_name_cache[file_name] = file_name |
| 399 | + return file_name |
| 400 | + |
| 401 | + |
| 402 | +def _is_user_file(file_name: str) -> bool: |
| 403 | + is_user = _traceback_caches.is_user_file_cache.get(file_name, None) |
| 404 | + if is_user is not None: |
| 405 | + return is_user |
| 406 | + out = is_user_filename(file_name) |
| 407 | + _traceback_caches.is_user_file_cache[file_name] = out |
| 408 | + return out |
| 409 | + |
| 410 | + |
| 411 | +def _traceback_to_location(tb: Traceback) -> Location: |
| 412 | + """Converts a full traceback to a callsite() MLIR location.""" |
| 413 | + loc = _traceback_caches.traceback_cache.get(tb, None) |
| 414 | + if loc is not None: |
| 415 | + return loc |
| 416 | + |
| 417 | + frame_locs = [] |
| 418 | + frames_limit = _traceback_in_locations_limit |
| 419 | + frames_limit = frames_limit if frames_limit >= 0 else 1000 |
| 420 | + |
| 421 | + codes, lastis = tb.raw_frames() |
| 422 | + for i, code in enumerate(codes): |
| 423 | + if not _is_user_file(code.co_filename): |
| 424 | + continue |
| 425 | + |
| 426 | + lasti = lastis[i] |
| 427 | + code_lasti = code, lasti |
| 428 | + loc = _traceback_caches.location_cache.get(code_lasti, None) |
| 429 | + if loc is None: |
| 430 | + frame = raw_frame_to_frame(code, lasti) |
| 431 | + file_loc = Location.file( |
| 432 | + get_canonical_source_file(frame.file_name, _traceback_caches), |
| 433 | + frame.start_line, |
| 434 | + frame.start_column, |
| 435 | + frame.end_line, |
| 436 | + frame.end_column, |
| 437 | + ) |
| 438 | + loc = Location.name(frame.function_name, childLoc=file_loc) |
| 439 | + _traceback_caches.location_cache[code_lasti] = loc |
| 440 | + frame_locs.append(loc) |
| 441 | + if len(frame_locs) >= frames_limit: |
| 442 | + break |
| 443 | + |
| 444 | + n = len(frame_locs) |
| 445 | + if n == 0: |
| 446 | + loc = Location.unknown() |
| 447 | + elif n == 1: |
| 448 | + loc = frame_locs[0] |
| 449 | + else: |
| 450 | + loc = Location.callsite(frame_locs[0], frame_locs[1:]) |
| 451 | + _traceback_caches.traceback_cache[tb] = loc |
| 452 | + return loc |
| 453 | + |
| 454 | + |
| 455 | +def source_info_to_location( |
| 456 | + primitive: None, |
| 457 | + name_stack: NameStack, |
| 458 | + traceback: Traceback | None, |
| 459 | +) -> Location: |
| 460 | + if _include_full_tracebacks_in_locations: |
| 461 | + if traceback is None: |
| 462 | + loc = Location.unknown() |
| 463 | + else: |
| 464 | + loc = _traceback_to_location(traceback) |
| 465 | + else: |
| 466 | + frame = user_frame(traceback) |
| 467 | + if frame is None: |
| 468 | + loc = Location.unknown() |
| 469 | + else: |
| 470 | + loc = Location.file( |
| 471 | + get_canonical_source_file(frame.file_name, _traceback_caches), |
| 472 | + frame.start_line, |
| 473 | + frame.start_column, |
| 474 | + ) |
| 475 | + if primitive is None: |
| 476 | + if name_stack.stack: |
| 477 | + loc = Location.name(str(name_stack), childLoc=loc) |
| 478 | + else: |
| 479 | + eqn_str = ( |
| 480 | + f"{name_stack}/{primitive.name}" if name_stack.stack else primitive.name |
| 481 | + ) |
| 482 | + loc = Location.name(eqn_str, childLoc=loc) |
| 483 | + loc = Location.name(f"{primitive.name}:", childLoc=loc) |
| 484 | + return loc |
0 commit comments