Skip to content

Commit 0a70a66

Browse files
JokerenFindHao
authored andcommitted
[LINEINFO] Relocate IR when compile from IR source (triton-lang#6502)
1 parent 056a69b commit 0a70a66

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

python/test/unit/language/test_line_info.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,3 +220,35 @@ def test_line_info_env(monkeypatch, status: str):
220220
kernel_info = kernel_single.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1, ))
221221
file_lines = extract_file_lines(command, anchor, separator, kernel_info.asm[obj_kind])
222222
assert len(file_lines) == 0 if status == "1" else len(file_lines) > 0
223+
224+
225+
@pytest.mark.parametrize("status", ["ttir", ""])
226+
def test_line_info_ir_source(monkeypatch, status, tmp_path):
227+
try:
228+
obj_kind, command, anchor, separator = get_disassembler_command_and_debug_line_format()
229+
except BaseException:
230+
pytest.skip("disassembler is not available")
231+
232+
src = """
233+
#loc = loc("/path/test.py":7:0)
234+
module {
235+
tt.func public @test(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/path/test.py":7:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/path/test.py":7:0)) attributes {noinline = false} {
236+
%0 = tt.load %arg0 : !tt.ptr<f32> loc(#loc1)
237+
tt.store %arg1, %0 : !tt.ptr<f32> loc(#loc2)
238+
tt.return loc(#loc3)
239+
} loc(#loc)
240+
} loc(#loc)
241+
#loc1 = loc("/path/test.py":8:16)
242+
#loc2 = loc("/path/test.py":9:20)
243+
#loc3 = loc("/path/test.py":9:4)
244+
"""
245+
monkeypatch.setenv("USE_IR_LOC", status)
246+
temp_file = tmp_path / "test.ttir"
247+
temp_file.write_text(src)
248+
kernel_info = triton.compile(str(temp_file))
249+
file_lines = extract_file_lines(command, anchor, separator, kernel_info.asm[obj_kind])
250+
if status == "ttir":
251+
assert check_file_lines(file_lines, "/path/test.py", 8, should_contain=False)
252+
assert check_file_lines(file_lines, str(temp_file), -1, should_contain=True)
253+
else:
254+
assert check_file_lines(file_lines, "/path/test.py", 8, should_contain=True)

python/triton/compiler/compiler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,10 @@ def compile(src, target=None, options=None):
288288
filter_traceback(e)
289289
raise
290290
use_ir_loc = os.environ.get("USE_IR_LOC", None)
291+
if ir_source and use_ir_loc:
292+
module.create_location_snapshot(src.path)
293+
print(f"Creating new locations for {src.path}")
294+
291295
for ext, compile_ir in list(stages.items())[first_stage:]:
292296
next_module = compile_ir(module, metadata)
293297
ir_filename = f"{file_name}.{ext}"

0 commit comments

Comments
 (0)