Skip to content

Commit c09680f

Browse files
authored
[FIX] [SANITIZER] Report correct line numbers for OOB errors in loops (#251)
1 parent 3a20bcc commit c09680f

File tree

3 files changed

+156
-9
lines changed

3 files changed

+156
-9
lines changed

tests/end_to_end/test_sanitizer.py

Lines changed: 65 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -131,9 +131,12 @@ def _check_range_satisfiable(
131131
access_addr: Z3Expr,
132132
expr_constraints: Optional[BoolRef],
133133
symbolic_expr: SymbolicExpr,
134+
source_location: Optional[tuple[str, int, str]] = None,
134135
) -> None:
135136
self.check_inside_loop.append((access_addr, expr_constraints))
136-
super()._check_range_satisfiable(access_addr, expr_constraints, symbolic_expr)
137+
super()._check_range_satisfiable(
138+
access_addr, expr_constraints, symbolic_expr, source_location
139+
)
137140

138141
def register_for_loop_callback(self) -> ForLoopCallbacks:
139142
callbacks = super().register_for_loop_callback()
@@ -296,3 +299,64 @@ def test_loop_deferred_checks_simplify():
296299
loop_deferred_check_simplify_kernel[(2,)](out)
297300

298301
assert load_index_checker.observed_offsets == []
302+
303+
304+
# ======== Line Number Reporting Tests ===========
305+
306+
307+
# Create a dedicated sanitizer for line number tests
308+
line_number_checker: SymbolicSanitizer = SymbolicSanitizer(abort_on_error=False)
309+
310+
311+
@triton_viz.trace(client=line_number_checker)
@triton.jit
def oob_in_loop_kernel(ptr, N: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    """Kernel where OOB occurs inside a loop at the tl.load line."""
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    for _ in range(2):
        # This line should be reported in the traceback when OOB occurs.
        # Keep the load and its offset on one physical line: the test below
        # asserts on the text of the reported source line.
        val = tl.load(ptr + offsets + 1000)  # OOB access due to +1000 offset
        tl.store(ptr + offsets, val)
321+
322+
323+
def test_loop_oob_reports_correct_line_number():
    """
    Test that sanitizer reports the correct line number for OOB errors in loops.

    Previously, the sanitizer would report the function definition line instead
    of the actual tl.load/tl.store line that caused the error. This was because
    traceback was captured at loop exit time rather than when the memory
    operation was executed.
    """
    # The checker is shared at module level; start from a clean slate.
    line_number_checker.records.clear()

    data = torch.zeros((16,), dtype=torch.float32)

    oob_in_loop_kernel[(1,)](data, N=16, BLOCK_SIZE=16)

    # Verify that OOB was detected
    assert len(line_number_checker.records) > 0, "Expected OOB to be detected"

    record = line_number_checker.records[0]
    assert len(record.user_code_tracebacks) > 0, "Expected traceback info"

    tb_info = record.user_code_tracebacks[0]

    # The error should point to the tl.load line, not the function definition
    assert "tl.load" in tb_info.line_of_code, (
        "Expected traceback to point to tl.load line, "
        f"but got: {tb_info.line_of_code!r}"
    )

    # Verify the line contains the OOB offset ("1000" also covers "+1000")
    assert "1000" in tb_info.line_of_code, (
        "Expected line to contain the OOB offset, "
        f"but got: {tb_info.line_of_code!r}"
    )

    # Verify function name
    assert tb_info.func_name == "oob_in_loop_kernel", (
        "Expected func_name to be 'oob_in_loop_kernel', "
        f"but got: {tb_info.func_name!r}"
    )

triton_viz/clients/sanitizer/report.py

Lines changed: 58 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,8 @@
1+
import linecache
2+
import sys
13
import traceback
2-
from typing import Optional, TYPE_CHECKING
4+
from pathlib import Path
5+
from typing import TYPE_CHECKING
36

47
import numpy as np
58

@@ -15,6 +18,58 @@
1518
from .sanitizer import SymbolicExpr
1619

1720

21+
# Path fragments that identify framework code (not user code).
# Each entry is matched as a substring of a frame's POSIX-style filename.
_FRAMEWORK_PATHS = [
    "triton_viz/core/",
    "triton_viz/clients/",
    "triton/runtime/",
    "triton/language/",
    "site-packages/triton/",
]
29+
30+
31+
def _get_user_code_location() -> tuple[str, int, str] | None:
    """Return the innermost call-stack location that belongs to user code.

    Walks outward from the caller of this helper and returns the first frame
    whose file is neither a synthetic ``<...>`` source nor matched by
    ``_FRAMEWORK_PATHS``. A file whose path contains ``examples/`` always
    counts as user code, even when it also matches a framework path.

    Only frame attributes are read; no source text is loaded here.

    Returns:
        ``(filename, lineno, func_name)`` of the first user-code frame, or
        ``None`` when no frame on the stack qualifies.
    """
    from types import FrameType

    # Start at the caller: this helper's own frame is never user code and must
    # not be reported even if this module's path is not covered by
    # _FRAMEWORK_PATHS.
    frame: FrameType | None = sys._getframe().f_back

    while frame is not None:
        code = frame.f_code
        # as_posix() yields "/" separators so the fragments below match
        # regardless of the platform's native separator.
        filename = Path(code.co_filename).as_posix()

        # Names such as "<string>" or "<frozen ...>" are Python internals.
        if not filename.startswith("<"):
            is_framework = any(
                fragment in filename for fragment in _FRAMEWORK_PATHS
            )
            # User code: not in framework paths, OR in an examples directory.
            if not is_framework or "examples/" in filename:
                return (code.co_filename, frame.f_lineno, code.co_name)

        frame = frame.f_back

    return None
58+
59+
60+
def _location_to_traceback_info(
    source_location: tuple[str, int, str],
) -> TracebackInfo:
    """Expand a lightweight source location into a ``TracebackInfo``.

    Args:
        source_location: ``(filename, lineno, func_name)`` as produced by
            ``_get_user_code_location``.

    Returns:
        A ``TracebackInfo`` carrying the same three fields plus the text of
        the referenced source line.
    """
    filename, lineno, func_name = source_location
    # linecache.getline returns "" when the line cannot be read. rstrip()
    # drops the trailing newline but keeps the line's leading indentation.
    line_of_code = linecache.getline(filename, lineno).rstrip()
    return TracebackInfo(
        filename=filename,
        lineno=lineno,
        func_name=func_name,
        line_of_code=line_of_code,
    )
71+
72+
1873
def print_oob_record(oob_record: OutOfBoundsRecord, max_display=10):
1974
"""
2075
Print detailed logs for a given OOB record.
@@ -95,7 +150,7 @@ def print_oob_record(oob_record: OutOfBoundsRecord, max_display=10):
95150

96151

97152
def print_oob_record_pdb_style(
98-
oob_record: OutOfBoundsRecord, symbolic_expr: Optional["SymbolicExpr"] = None
153+
oob_record: OutOfBoundsRecord, symbolic_expr: "SymbolicExpr | None" = None
99154
):
100155
"""
101156
Print a comprehensive diagnostic report for OOB errors in PDB-style format.
@@ -104,7 +159,7 @@ def print_oob_record_pdb_style(
104159
----------
105160
oob_record : OutOfBoundsRecord
106161
The record containing information about out-of-bounds accesses.
107-
symbolic_expr : Optional[SymbolicExpr]
162+
symbolic_expr : SymbolicExpr | None
108163
The symbolic expression tree that led to the OOB access.
109164
"""
110165
from pathlib import Path

triton_viz/clients/sanitizer/sanitizer.py

Lines changed: 33 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,13 @@
8484
check_inner_stride_equal_to_one,
8585
)
8686
from .data import OutOfBoundsRecordZ3
87-
from .report import _get_traceback_info, print_oob_record, print_oob_record_pdb_style
87+
from .report import (
88+
_get_traceback_info,
89+
_get_user_code_location,
90+
_location_to_traceback_info,
91+
print_oob_record,
92+
print_oob_record_pdb_style,
93+
)
8894
from ...core.config import config as cfg
8995

9096

@@ -163,6 +169,9 @@ class PendingCheck:
163169
symbolic_expr: "SymbolicExpr"
164170
addr_expr: Z3Expr
165171
constraints: ConstraintConjunction
172+
# Lightweight source location: (filename, lineno, func_name)
173+
# Captured immediately to preserve accurate line info for deferred checks
174+
source_location: Optional[tuple[str, int, str]] = None
166175

167176

168177
@dataclass
@@ -1707,6 +1716,7 @@ def _check_range_satisfiable(
17071716
access_addr: Z3Expr,
17081717
expr_constraints: ConstraintConjunction,
17091718
symbolic_expr: SymbolicExpr,
1719+
source_location: Optional[tuple[str, int, str]] = None,
17101720
) -> None:
17111721
# Use push/pop on persistent solver
17121722
solver = self.solver
@@ -1739,8 +1749,10 @@ def _check_single_addr(addr_expr: Z3Expr) -> None:
17391749
else:
17401750
op_type = Load
17411751

1742-
# Report with symbolic expression
1743-
self._report(op_type, tensor, violation_addr, symbolic_expr)
1752+
# Report with symbolic expression and source location
1753+
self._report(
1754+
op_type, tensor, violation_addr, symbolic_expr, source_location
1755+
)
17441756
solver.pop()
17451757

17461758
if isinstance(access_addr, list):
@@ -1767,11 +1779,16 @@ def _handle_access_check(self, expr: SymbolicExpr) -> None:
17671779
signature = _make_signature(z3_addr, z3_constraints)
17681780
pending_idx = ctx.signature_cache.get(signature)
17691781
if pending_idx is None:
1782+
# Capture source location now while we're still in the user's tl.load/tl.store call.
1783+
# This is a lightweight operation that only traverses frame objects.
1784+
# The actual source line will be read later only if an error is detected.
1785+
source_location = _get_user_code_location()
17701786
ctx.signature_cache[signature] = len(ctx.pending_checks)
17711787
pending_check = PendingCheck(
17721788
symbolic_expr=expr,
17731789
addr_expr=z3_addr,
17741790
constraints=z3_constraints,
1791+
source_location=source_location,
17751792
)
17761793
ctx.pending_checks.append(pending_check)
17771794
else:
@@ -1784,8 +1801,15 @@ def _report(
17841801
tensor: Tensor,
17851802
violation_address: int,
17861803
symbolic_expr: Optional[SymbolicExpr] = None,
1804+
source_location: Optional[tuple[str, int, str]] = None,
17871805
) -> None:
1788-
traceback_info = _get_traceback_info()
1806+
# Use pre-captured location if available (for deferred checks in loops),
1807+
# otherwise capture it now (for immediate checks outside loops)
1808+
if source_location is not None:
1809+
traceback_info = [_location_to_traceback_info(source_location)]
1810+
else:
1811+
traceback_info = _get_traceback_info()
1812+
17891813
tensor_name = self._get_tensor_name(tensor)
17901814
oob_record = OutOfBoundsRecordZ3(
17911815
op_type=op_type,
@@ -2285,7 +2309,7 @@ def loop_hook_after(lineno: int) -> None:
22852309
if not self.loop_stack or self.loop_stack[-1].lineno != lineno:
22862310
return
22872311
ctx = self.loop_stack.pop()
2288-
# execute pending checks
2312+
# Execute pending checks that were deferred during loop execution
22892313
solver = self.solver
22902314
addr_sym = self.addr_sym
22912315
assert solver is not None
@@ -2302,6 +2326,9 @@ def loop_hook_after(lineno: int) -> None:
23022326
addr_expr = pending_check.addr_expr
23032327
expr_constraints = pending_check.constraints
23042328
symbolic_expr = pending_check.symbolic_expr
2329+
# Use the source location captured when the check was created,
2330+
# not the current location (which would be the loop exit point)
2331+
source_location = pending_check.source_location
23052332

23062333
if cfg.verbose:
23072334
print(
@@ -2315,6 +2342,7 @@ def loop_hook_after(lineno: int) -> None:
23152342
addr_expr,
23162343
expr_constraints,
23172344
symbolic_expr,
2345+
source_location,
23182346
)
23192347
if ctx.pending_checks:
23202348
solver.pop()

0 commit comments

Comments
 (0)