Skip to content

Commit fc73bb3

Browse files
FindHao and facebook-github-bot
authored and committed
Wire torch trace dir through parse pipeline and CLI
Summary: Add --torch-trace-dir CLI parameter and wire it through the full pipeline: unified_parse() -> oss_run() -> parse_logs() -> parse_single_file(). Includes auto-discovery logic: when --torch-trace-dir is not specified, torch trace log files are automatically searched in the same directory as tritonparse logs. This enables kernel compile attribution in multi-process scenarios without requiring explicit user configuration. Also exports CompileInfo, discover_torch_trace_files, and parse_torch_trace_logs from the parse module's public API. Differential Revision: D95080073
1 parent 56a801e commit fc73bb3

File tree

4 files changed

+261
-2
lines changed

4 files changed

+261
-2
lines changed
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
"""Integration tests for the full parse pipeline with torch trace log support."""
3+
4+
import json
5+
import os
6+
import tempfile
7+
import unittest
8+
9+
from tritonparse.parse.common import (
10+
_build_kernel_compile_mapping,
11+
parse_logs,
12+
RankConfig,
13+
)
14+
15+
16+
def _make_glog_line(metadata_dict: dict) -> str:
17+
"""Helper to create a glog-formatted line with JSON metadata."""
18+
return f"V0302 14:30:00.123456 12345 torch/_logging/_internal.py:1489] {json.dumps(metadata_dict)}"
19+
20+
21+
def _make_torch_trace_log(frame_id, frame_compile_id, kernel_paths):
    """Build the text of a torch trace log holding one inductor_output_code event.

    The first line is a glog line whose JSON metadata carries ``frame_id``
    and ``frame_compile_id``. For every entry of ``kernel_paths`` it is
    followed by a tab-indented ``# kernel path:`` comment and a tab-indented
    placeholder ``async_compile.triton`` line. The returned text ends with a
    newline.
    """
    event_metadata = {
        "inductor_output_code": {
            "filename": "output.py",
            "file_path": "/tmp/output.py",
        },
        "frame_id": frame_id,
        "frame_compile_id": frame_compile_id,
        "attempt": 0,
        "has_payload": "abc123",
    }
    log_lines = [_make_glog_line(event_metadata)]
    for kernel_path in kernel_paths:
        # Each kernel contributes a path comment plus a stub compile call.
        log_lines.extend(
            (
                f"\t# kernel path: {kernel_path}",
                "\ttriton_kernel = async_compile.triton('kernel', '''...''')",
            )
        )
    return "\n".join(log_lines) + "\n"
38+
39+
40+
def _make_tritonparse_trace(events):
41+
"""Create content for a tritonparse trace NDJSON file."""
42+
return "\n".join(json.dumps(e) for e in events) + "\n"
43+
44+
45+
class TestBuildKernelCompileMapping(unittest.TestCase):
    """Checks where _build_kernel_compile_mapping looks for torch trace logs."""

    def test_auto_discover_in_raw_log_dir(self):
        """A torch trace log placed in the raw log directory is found automatically."""
        kernel_path = "/tmp/torchinductor_user/ab/kernel.py"
        with tempfile.TemporaryDirectory() as tmpdir:
            # Put the torch trace log directly in the directory being parsed.
            log_file = os.path.join(tmpdir, "dedicated_log_torch_trace_rank_0_abc.log")
            log_text = _make_torch_trace_log(
                frame_id=0,
                frame_compile_id=0,
                kernel_paths=[kernel_path],
            )
            with open(log_file, "w") as handle:
                handle.write(log_text)

            result = _build_kernel_compile_mapping(tmpdir)
            self.assertIsNotNone(result)
            self.assertIn(kernel_path, result)
            self.assertEqual(result[kernel_path].frame_id, 0)

    def test_explicit_torch_trace_dir(self):
        """A torch trace log in a separately supplied directory is picked up."""
        with tempfile.TemporaryDirectory() as log_dir, tempfile.TemporaryDirectory() as torch_dir:
            log_file = os.path.join(
                torch_dir, "dedicated_log_torch_trace_rank_0_abc.log"
            )
            log_text = _make_torch_trace_log(
                frame_id=1,
                frame_compile_id=0,
                kernel_paths=["/tmp/torchinductor_user/cd/kernel.py"],
            )
            with open(log_file, "w") as handle:
                handle.write(log_text)

            result = _build_kernel_compile_mapping(log_dir, torch_dir)
            self.assertIsNotNone(result)
            self.assertEqual(len(result), 1)

    def test_no_torch_trace_files(self):
        """An empty directory produces None instead of an empty mapping."""
        with tempfile.TemporaryDirectory() as tmpdir:
            result = _build_kernel_compile_mapping(tmpdir)
            self.assertIsNone(result)
92+
93+
94+
class TestParseLogsWithTorchTrace(unittest.TestCase):
    """Runs parse_logs end to end with a torch trace log beside the trace file."""

    def test_end_to_end_mapping(self):
        """A compilation without pt_info is attributed through the torch trace log."""
        kernel_path = "/tmp/torchinductor_user/ab/cabcdef.py"

        # The compilation payload deliberately carries no pt_info
        # (multi-process scenario), so the frame must come from the mapping.
        compilation_event = {
            "event_type": "compilation",
            "pid": 1000,
            "stack": [],
            "payload": {
                "metadata": {"hash": "test_hash", "name": "test_kernel"},
                "file_content": {},
                "file_path": {},
                "python_source": {"file_path": kernel_path},
            },
        }
        launch_event = {
            "event_type": "launch",
            "name": "test_kernel",
            "pid": 1000,
            "stack": [],
            "compilation_metadata": {"hash": "test_hash"},
        }

        with tempfile.TemporaryDirectory() as tmpdir:
            # Torch trace log naming frame 3 / compile 1 for the kernel path.
            torch_log_path = os.path.join(
                tmpdir, "dedicated_log_torch_trace_rank_0_test.log"
            )
            torch_content = _make_torch_trace_log(
                frame_id=3,
                frame_compile_id=1,
                kernel_paths=[kernel_path],
            )
            with open(torch_log_path, "w") as handle:
                handle.write(torch_content)

            # Tritonparse trace log written next to the torch trace log.
            triton_log_path = os.path.join(
                tmpdir, "dedicated_log_triton_trace_user_.ndjson"
            )
            with open(triton_log_path, "w") as handle:
                handle.write(
                    _make_tritonparse_trace([compilation_event, launch_event])
                )

            # all_ranks=True so that files without a rank suffix are parsed.
            rank_config = RankConfig(all_ranks=True)
            parsed_dir, _file_mapping = parse_logs(
                tmpdir,
                rank_config,
                verbose=False,
                split_inductor_compilations=True,
            )

            # Collect the names of every file generated under the output dir.
            all_files = [
                name
                for _root, _dirs, names in os.walk(parsed_dir)
                for name in names
            ]
            frame_files = [name for name in all_files if name.startswith("f")]

            # Expect a frame-specific output such as f3_fc1_a0_cai-.ndjson.gz.
            self.assertTrue(
                any("f3_fc1" in name for name in frame_files),
                f"Expected frame file with f3_fc1 but got: {all_files}",
            )
164+
165+
166+
if __name__ == "__main__":
167+
unittest.main()

tritonparse/parse/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@
2828
create_python_mapping,
2929
)
3030
from .source_type import Source, SourceType
31+
from .torch_trace_parser import (
32+
CompileInfo,
33+
discover_torch_trace_files,
34+
parse_torch_trace_logs,
35+
)
3136
from .trace_processor import (
3237
generate_source_mappings,
3338
parse_single_file,
@@ -55,6 +60,10 @@
5560
"_add_parse_args",
5661
"oss_run",
5762
"unified_parse",
63+
# Torch trace parser
64+
"CompileInfo",
65+
"discover_torch_trace_files",
66+
"parse_torch_trace_logs",
5867
# IR parsing
5968
"extract_code_locations",
6069
"extract_loc_definitions",

tritonparse/parse/common.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import tempfile
99
from collections import defaultdict
1010
from pathlib import Path
11-
from typing import Optional, Tuple
11+
from typing import Any, Dict, List, Optional, Tuple
1212

1313
from tritonparse.shared_vars import (
1414
DEFAULT_TRACE_FILE_PREFIX_WITHOUT_USER as LOG_PREFIX,
@@ -312,12 +312,66 @@ def copy_local_to_tmpdir(local_path: str, verbose: bool = False) -> str:
312312
return temp_dir
313313

314314

315+
def _build_kernel_compile_mapping(
    raw_log_dir: str,
    torch_trace_dir: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """
    Build kernel compile mapping from inductor's torch trace logs.

    Searches for torch trace log files and parses them to extract
    kernel_source_path -> CompileInfo mappings. These mappings allow
    attribution of Triton kernels to their originating compilation frame
    when pt_info is missing (multi-process Triton JIT scenarios).

    Args:
        raw_log_dir: Directory containing tritonparse logs. It is always
            searched for torch trace logs, after torch_trace_dir if one is given.
        torch_trace_dir: Optional explicit directory containing torch trace
            logs. It is searched first. If it is set but is not an existing
            directory, a warning is logged and only raw_log_dir is searched.

    Returns:
        Dict mapping kernel source paths to CompileInfo, or None if no torch
        trace logs were found or parsing produced an empty mapping.
    """
    from .torch_trace_parser import discover_torch_trace_files, parse_torch_trace_logs

    # Determine where to look for torch trace logs (explicit directory first).
    search_dirs: List[str] = []
    if torch_trace_dir:
        if os.path.isdir(torch_trace_dir):
            search_dirs.append(torch_trace_dir)
        else:
            # An explicitly requested directory that cannot be used should not
            # be dropped silently: the user would otherwise get no attribution
            # and no hint as to why.
            logger.warning(
                "Torch trace dir %s is not a directory; "
                "only searching %s for torch trace logs",
                torch_trace_dir,
                raw_log_dir,
            )
    # Also check the raw log directory (torch trace logs may coexist).
    if os.path.isdir(raw_log_dir):
        search_dirs.append(raw_log_dir)

    all_log_paths: List[str] = []
    # Deduplicate on the resolved path so that the same file reached through
    # two spellings of a directory (e.g. "logs" vs "./logs", or a symlink) is
    # parsed only once. The first-seen spelling is what gets parsed.
    seen_real_paths = set()
    for search_dir in search_dirs:
        torch_files = discover_torch_trace_files(search_dir)
        for rank_files in torch_files.values():
            for path in rank_files:
                real_path = os.path.realpath(path)
                if real_path in seen_real_paths:
                    continue
                seen_real_paths.add(real_path)
                all_log_paths.append(path)

    if not all_log_paths:
        return None

    mapping = parse_torch_trace_logs(all_log_paths)
    if not mapping:
        # Normalise an empty mapping to None so callers only test for None.
        return None

    logger.info(
        "Built kernel compile mapping with %d entries from %d torch trace log(s)",
        len(mapping),
        len(all_log_paths),
    )
    return mapping
366+
367+
315368
def parse_logs(
316369
logs_to_parse: str,
317370
rank_config: RankConfig,
318371
verbose: bool = False,
319372
tritonparse_url_prefix: str = "",
320373
split_inductor_compilations: bool = True,
374+
torch_trace_dir: Optional[str] = None,
321375
) -> Tuple[str, dict]:
322376
"""
323377
Parse logs.
@@ -330,6 +384,10 @@ def parse_logs(
330384
split_inductor_compilations: Whether to split
331385
output files by frame_id, compile_id, attempt_id, and compiled_autograd_id.
332386
Defaults to True. This rule follows tlparse's behavior.
387+
torch_trace_dir: Optional path to directory containing inductor torch trace
388+
logs. When provided, kernel compilation attribution will use these logs to
389+
recover frame_id/compile_id for kernels compiled in multi-process scenarios.
390+
If None, auto-discovers torch trace files in the same directory as tritonparse logs.
333391
Returns:
334392
Tuple of (parsed log directory, file mapping)
335393
"""
@@ -372,6 +430,10 @@ def parse_logs(
372430
ranks[Rank(Rank.NO_RANK)].append(path)
373431
if not ranks:
374432
raise RuntimeError(f"No eligible structured trace logs found in {raw_log_dir}")
433+
434+
# Build kernel compile mapping from torch trace logs (if available)
435+
kernel_compile_mapping = _build_kernel_compile_mapping(raw_log_dir, torch_trace_dir)
436+
375437
file_mapping = {"tritonparse_url_prefix": tritonparse_url_prefix}
376438
# Parse each eligible log
377439
for rank, files in ranks.items():
@@ -406,7 +468,12 @@ def parse_logs(
406468
relative_path = "" if rank.is_no_rank else rank.to_string("")
407469
output_dir = os.path.join(parsed_log_dir, relative_path)
408470
# Parse the file
409-
parse_single_file(input_file, output_dir, split_inductor_compilations)
471+
parse_single_file(
472+
input_file,
473+
output_dir,
474+
split_inductor_compilations,
475+
kernel_compile_mapping=kernel_compile_mapping,
476+
)
410477
# Collect generated files after parsing and gzip them immediately
411478
if os.path.exists(output_dir):
412479
generated_files = []

tritonparse/parse/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,16 @@ def _add_parse_args(parser: argparse.ArgumentParser) -> None:
7676
action="store_true",
7777
)
7878
parser.add_argument("-v", "--verbose", help="Verbose logging", action="store_true")
79+
parser.add_argument(
80+
"--torch-trace-dir",
81+
type=str,
82+
default=None,
83+
help=(
84+
"Path to directory containing inductor torch trace logs. "
85+
"Used to recover kernel compilation attribution in multi-process scenarios. "
86+
"If not specified, auto-discovers torch trace files alongside tritonparse logs."
87+
),
88+
)
7989
if is_fbcode():
8090
from tritonparse.fb.utils import append_parser
8191

@@ -91,6 +101,7 @@ def oss_run(
91101
verbose: bool = False,
92102
split_inductor_compilations: bool = True,
93103
skip_logger: bool = True,
104+
torch_trace_dir: Optional[str] = None,
94105
):
95106
"""
96107
Main function for tritonparse. It is for OSS only.
@@ -103,6 +114,7 @@ def oss_run(
103114
all_ranks: Analyze all ranks
104115
verbose: Verbose logging
105116
skip_logger: Unused in OSS, kept for API compatibility.
117+
torch_trace_dir: Path to directory containing inductor torch trace logs.
106118
"""
107119
source = Source(source, verbose)
108120
rank_config = RankConfig.from_cli_args(rank, all_ranks, source.type)
@@ -137,6 +149,7 @@ def oss_run(
137149
rank_config,
138150
verbose,
139151
split_inductor_compilations=split_inductor_compilations,
152+
torch_trace_dir=torch_trace_dir,
140153
)
141154
else:
142155
parsed_log_dir = source.value
@@ -161,6 +174,7 @@ def unified_parse(
161174
verbose: bool = False,
162175
split_inductor_compilations: bool = True,
163176
skip_logger: bool = False,
177+
torch_trace_dir: Optional[str] = None,
164178
**kwargs,
165179
):
166180
"""
@@ -174,6 +188,7 @@ def unified_parse(
174188
all_ranks: Whether to analyze all ranks
175189
verbose: Whether to enable verbose logging
176190
skip_logger: Whether to skip usage logging (default: False).
191+
torch_trace_dir: Path to directory containing inductor torch trace logs.
177192
"""
178193
# Log usage for API invocations
179194
if not skip_logger and is_fbcode():
@@ -196,6 +211,7 @@ def unified_parse(
196211
verbose=verbose,
197212
split_inductor_compilations=split_inductor_compilations,
198213
skip_logger=skip_logger,
214+
torch_trace_dir=torch_trace_dir,
199215
**kwargs,
200216
)
201217
return output

0 commit comments

Comments
 (0)