class TestResolveCompileInfo(unittest.TestCase):
    """Exercise _resolve_compile_info against a fixed two-kernel mapping."""

    KERNEL1_PATH = "/tmp/torchinductor_user/ab/kernel1.py"
    KERNEL2_PATH = "/tmp/torchinductor_user/cd/kernel2.py"

    def setUp(self):
        # Build a fresh mapping for every test so tests stay independent.
        self.mapping = {
            self.KERNEL1_PATH: CompileInfo(frame_id=0, frame_compile_id=0),
            self.KERNEL2_PATH: CompileInfo(
                frame_id=1, frame_compile_id=0, attempt=1
            ),
        }

    def _resolve(self, event):
        """Run _resolve_compile_info for ``event`` using this test's mapping."""
        return _resolve_compile_info(event, self.mapping)

    def test_resolve_via_python_source(self):
        """A payload python_source.file_path found in the mapping is resolved."""
        info = self._resolve(
            {
                "payload": {"python_source": {"file_path": self.KERNEL1_PATH}},
                "stack": [],
            }
        )
        self.assertIsNotNone(info)
        self.assertEqual(info.frame_id, 0)
        self.assertEqual(info.frame_compile_id, 0)

    def test_resolve_via_stack_trace(self):
        """Without python_source, a mapped filename in the stack is resolved."""
        stack_frames = [
            {"filename": "/user/code.py", "line": 10, "name": "main"},
            {"filename": self.KERNEL2_PATH, "line": 1, "name": "kernel"},
            {"filename": "triton/jit.py", "line": 50, "name": "run"},
        ]
        info = self._resolve({"payload": {}, "stack": stack_frames})
        self.assertIsNotNone(info)
        self.assertEqual(info.frame_id, 1)
        self.assertEqual(info.frame_compile_id, 0)
        self.assertEqual(info.attempt, 1)

    def test_no_match(self):
        """None is returned when neither the payload nor the stack matches."""
        unmatched_event = {
            "payload": {"python_source": {"file_path": "/tmp/unknown/path.py"}},
            "stack": [{"filename": "/user/code.py", "line": 10, "name": "main"}],
        }
        self.assertIsNone(self._resolve(unmatched_event))

    def test_empty_event(self):
        """An event with no keys at all resolves to None."""
        self.assertIsNone(self._resolve({}))

    def test_python_source_takes_priority(self):
        """python_source.file_path wins over a conflicting stack-trace match."""
        conflicting_event = {
            "payload": {"python_source": {"file_path": self.KERNEL1_PATH}},
            "stack": [
                {"filename": self.KERNEL2_PATH, "line": 1, "name": "kernel"}
            ],
        }
        info = self._resolve(conflicting_event)
        # kernel1 (payload) maps to frame_id=0; kernel2 (stack) would give 1.
        self.assertEqual(info.frame_id, 0)
"python_source": {"file_path": "/tmp/torchinductor_user/ab/kernel.py"} + }, + "stack": [], + } + fname = _determine_output_fname( + pt_info={}, + file_name_without_extension="trace", + split_inductor_compilations=True, + event=event, + kernel_compile_mapping=mapping, + ) + self.assertEqual(fname, "f3_fc2_a1_cai5.ndjson") + + def test_split_disabled(self): + """Test that splitting disabled always returns mapped filename.""" + fname = _determine_output_fname( + pt_info={"frame_id": 0, "frame_compile_id": 0}, + file_name_without_extension="trace", + split_inductor_compilations=False, + ) + self.assertEqual(fname, "trace_mapped.ndjson") + + def test_compiled_autograd_id_none(self): + """Test that compiled_autograd_id defaults to '-' when not set.""" + fname = _determine_output_fname( + pt_info={"frame_id": 0, "frame_compile_id": 0}, + file_name_without_extension="trace", + split_inductor_compilations=True, + ) + self.assertEqual(fname, "f0_fc0_a0_cai-.ndjson") + + +class TestParseSingleFileWithMapping(unittest.TestCase): + """Integration tests for parse_single_file with kernel_compile_mapping.""" + + def test_mapping_redirects_compilation_to_frame_file(self): + """Test that a compilation without pt_info is redirected when mapping is provided.""" + kernel_path = "/tmp/torchinductor_user/ab/cabcdef1234.py" + mapping = { + kernel_path: CompileInfo(frame_id=2, frame_compile_id=1), + } + + trace_lines = [ + json.dumps( + { + "event_type": "compilation", + "pid": 1000, + "stack": [], + "payload": { + "metadata": {"hash": "kernel_hash_1", "name": "kernel_1"}, + "file_content": {}, + "file_path": {}, + "python_source": {"file_path": kernel_path}, + # No pt_info — this is the multi-process scenario + }, + } + ), + json.dumps( + { + "event_type": "launch", + "name": "kernel_1", + "pid": 1000, + "stack": [], + "compilation_metadata": {"hash": "kernel_hash_1"}, + } + ), + ] + + with tempfile.TemporaryDirectory() as temp_dir: + input_file = os.path.join(temp_dir, 
"test_trace.ndjson") + with open(input_file, "w") as f: + for line in trace_lines: + f.write(line + "\n") + + output_dir = os.path.join(temp_dir, "output") + os.makedirs(output_dir) + + parse_single_file(input_file, output_dir, kernel_compile_mapping=mapping) + + output_files = os.listdir(output_dir) + # Should produce a frame-specific file, not _mapped + frame_files = [f for f in output_files if f.startswith("f")] + mapped_files = [f for f in output_files if "mapped" in f] + self.assertEqual(len(frame_files), 1) + self.assertEqual(len(mapped_files), 0) + self.assertEqual(frame_files[0], "f2_fc1_a0_cai-.ndjson") + + def test_no_mapping_falls_back_to_mapped(self): + """Test that without mapping, compilations without pt_info go to _mapped.""" + trace_lines = [ + json.dumps( + { + "event_type": "compilation", + "pid": 1000, + "stack": [], + "payload": { + "metadata": {"hash": "kernel_hash_1", "name": "kernel_1"}, + "file_content": {}, + "file_path": {}, + # No pt_info, no python_source + }, + } + ), + json.dumps( + { + "event_type": "launch", + "name": "kernel_1", + "pid": 1000, + "stack": [], + "compilation_metadata": {"hash": "kernel_hash_1"}, + } + ), + ] + + with tempfile.TemporaryDirectory() as temp_dir: + input_file = os.path.join(temp_dir, "test_trace.ndjson") + with open(input_file, "w") as f: + for line in trace_lines: + f.write(line + "\n") + + output_dir = os.path.join(temp_dir, "output") + os.makedirs(output_dir) + + parse_single_file(input_file, output_dir) + + output_files = os.listdir(output_dir) + mapped_files = [f for f in output_files if "mapped" in f] + self.assertGreater(len(mapped_files), 0) + + def test_mixed_with_and_without_pt_info(self): + """Test a mix of events: some with pt_info, some resolved via mapping.""" + kernel_path_a = "/tmp/torchinductor_user/ab/kernel_a.py" + mapping = { + kernel_path_a: CompileInfo(frame_id=0, frame_compile_id=0), + } + + trace_lines = [ + # Compilation WITH pt_info (should be split normally) + json.dumps( + { + 
"event_type": "compilation", + "pid": 1000, + "stack": [], + "payload": { + "metadata": {"hash": "hash_with_pt", "name": "kernel_with_pt"}, + "file_content": {}, + "file_path": {}, + "pt_info": { + "frame_id": 1, + "frame_compile_id": 0, + }, + }, + } + ), + json.dumps( + { + "event_type": "launch", + "name": "kernel_with_pt", + "pid": 1000, + "stack": [], + "compilation_metadata": {"hash": "hash_with_pt"}, + } + ), + # Compilation WITHOUT pt_info (should be resolved via mapping) + json.dumps( + { + "event_type": "compilation", + "pid": 1000, + "stack": [], + "payload": { + "metadata": { + "hash": "hash_without_pt", + "name": "kernel_without_pt", + }, + "file_content": {}, + "file_path": {}, + "python_source": {"file_path": kernel_path_a}, + }, + } + ), + json.dumps( + { + "event_type": "launch", + "name": "kernel_without_pt", + "pid": 1000, + "stack": [], + "compilation_metadata": {"hash": "hash_without_pt"}, + } + ), + ] + + with tempfile.TemporaryDirectory() as temp_dir: + input_file = os.path.join(temp_dir, "test_trace.ndjson") + with open(input_file, "w") as f: + for line in trace_lines: + f.write(line + "\n") + + output_dir = os.path.join(temp_dir, "output") + os.makedirs(output_dir) + + parse_single_file(input_file, output_dir, kernel_compile_mapping=mapping) + + output_files = sorted(os.listdir(output_dir)) + # Should have two frame files: f0_fc0 and f1_fc0 + frame_files = sorted([f for f in output_files if f.startswith("f")]) + self.assertEqual(len(frame_files), 2) + self.assertIn("f0_fc0_a0_cai-.ndjson", frame_files) + self.assertIn("f1_fc0_a0_cai-.ndjson", frame_files) + + def test_fake_compilation_with_mapping(self): + """Test that fake compilations can also be attributed via stack trace mapping.""" + kernel_path = "/tmp/torchinductor_user/ab/kernel_fake.py" + mapping = { + kernel_path: CompileInfo(frame_id=5, frame_compile_id=0), + } + + trace_lines = [ + # Only launch event (will trigger fake compilation) + json.dumps( + { + "event_type": "launch", + 
"name": "fake_kernel", + "pid": 1000, + "stack": [ + {"filename": "/user/code.py", "line": 10, "name": "main"}, + { + "filename": kernel_path, + "line": 1, + "name": "kernel_fn", + }, + ], + "compilation_metadata": { + "hash": "fake_hash", + "name": "fake_kernel", + "num_warps": 4, + }, + } + ), + ] + + with tempfile.TemporaryDirectory() as temp_dir: + input_file = os.path.join(temp_dir, "test_trace.ndjson") + with open(input_file, "w") as f: + for line in trace_lines: + f.write(line + "\n") + + output_dir = os.path.join(temp_dir, "output") + os.makedirs(output_dir) + + parse_single_file(input_file, output_dir, kernel_compile_mapping=mapping) + + output_files = os.listdir(output_dir) + frame_files = [f for f in output_files if f.startswith("f")] + self.assertEqual(len(frame_files), 1) + self.assertEqual(frame_files[0], "f5_fc0_a0_cai-.ndjson") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/cpu/test_pipeline_integration.py b/tests/cpu/test_pipeline_integration.py new file mode 100644 index 0000000..164c537 --- /dev/null +++ b/tests/cpu/test_pipeline_integration.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+"""Integration tests for the full parse pipeline with torch trace log support.""" + +import json +import os +import tempfile +import unittest + +from tritonparse.parse.common import ( + _build_kernel_compile_mapping, + parse_logs, + RankConfig, +) + + +def _make_glog_line(metadata_dict: dict) -> str: + """Helper to create a glog-formatted line with JSON metadata.""" + return f"V0302 14:30:00.123456 12345 torch/_logging/_internal.py:1489] {json.dumps(metadata_dict)}" + + +def _make_torch_trace_log(frame_id, frame_compile_id, kernel_paths): + """Create content for a torch trace log file with inductor_output_code event.""" + metadata = { + "inductor_output_code": { + "filename": "output.py", + "file_path": "/tmp/output.py", + }, + "frame_id": frame_id, + "frame_compile_id": frame_compile_id, + "attempt": 0, + "has_payload": "abc123", + } + lines = [_make_glog_line(metadata)] + for kp in kernel_paths: + lines.append(f"\t# kernel path: {kp}") + lines.append("\ttriton_kernel = async_compile.triton('kernel', '''...''')") + return "\n".join(lines) + "\n" + + +def _make_tritonparse_trace(events): + """Create content for a tritonparse trace NDJSON file.""" + return "\n".join(json.dumps(e) for e in events) + "\n" + + +class TestBuildKernelCompileMapping(unittest.TestCase): + """Tests for _build_kernel_compile_mapping.""" + + def test_auto_discover_in_raw_log_dir(self): + """Test that torch trace files are auto-discovered in the raw log directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create a torch trace log file in the same directory + torch_log = os.path.join(tmpdir, "dedicated_log_torch_trace_rank_0_abc.log") + content = _make_torch_trace_log( + frame_id=0, + frame_compile_id=0, + kernel_paths=["/tmp/torchinductor_user/ab/kernel.py"], + ) + with open(torch_log, "w") as f: + f.write(content) + + mapping = _build_kernel_compile_mapping(tmpdir) + self.assertIsNotNone(mapping) + self.assertIn("/tmp/torchinductor_user/ab/kernel.py", mapping) + 
class TestParseLogsWithTorchTrace(unittest.TestCase):
    """End-to-end check of parse_logs when a torch trace log sits next to the trace."""

    @staticmethod
    def _write_text(directory, file_name, text):
        """Write ``text`` to ``directory/file_name`` and return the full path."""
        target = os.path.join(directory, file_name)
        with open(target, "w") as handle:
            handle.write(text)
        return target

    def test_end_to_end_mapping(self):
        """A compilation lacking pt_info is attributed through the torch trace log."""
        kernel_path = "/tmp/torchinductor_user/ab/cabcdef.py"

        # Compilation event carries only python_source (no pt_info), which is
        # the multi-process scenario; the launch refers to it by hash.
        compilation_event = {
            "event_type": "compilation",
            "pid": 1000,
            "stack": [],
            "payload": {
                "metadata": {"hash": "test_hash", "name": "test_kernel"},
                "file_content": {},
                "file_path": {},
                "python_source": {"file_path": kernel_path},
            },
        }
        launch_event = {
            "event_type": "launch",
            "name": "test_kernel",
            "pid": 1000,
            "stack": [],
            "compilation_metadata": {"hash": "test_hash"},
        }

        with tempfile.TemporaryDirectory() as tmpdir:
            # Torch trace log that ties kernel_path to frame 3 / compile 1.
            self._write_text(
                tmpdir,
                "dedicated_log_torch_trace_rank_0_test.log",
                _make_torch_trace_log(
                    frame_id=3,
                    frame_compile_id=1,
                    kernel_paths=[kernel_path],
                ),
            )
            # Tritonparse trace placed in the same directory.
            self._write_text(
                tmpdir,
                "dedicated_log_triton_trace_user_.ndjson",
                _make_tritonparse_trace([compilation_event, launch_event]),
            )

            # all_ranks=True so the rank-less trace file is picked up.
            parsed_dir, _file_mapping = parse_logs(
                tmpdir,
                RankConfig(all_ranks=True),
                verbose=False,
                split_inductor_compilations=True,
            )

            # Flatten every generated file name under the parsed directory.
            all_files = [
                name
                for _root, _dirs, names in os.walk(parsed_dir)
                for name in names
            ]
            frame_files = [name for name in all_files if name.startswith("f")]

            # Expect a frame file such as f3_fc1_a0_cai-.ndjson.gz.
            self.assertTrue(
                any("f3_fc1" in name for name in frame_files),
                f"Expected frame file with f3_fc1 but got: {all_files}",
            )
def _make_output_code_event(
    frame_id: int = 0,
    frame_compile_id: int = 0,
    attempt: int = 0,
    compiled_autograd_id: int = None,
    kernel_paths: list = None,
) -> str:
    """Build one inductor_output_code record: a glog header line plus payload lines.

    The header carries the frame/compile identifiers as JSON; the
    ``compiled_autograd_id`` key is only present when a value is given.
    Each kernel path contributes a tab-indented payload stanza beginning with
    a ``# kernel path:`` comment. When ``kernel_paths`` is None a single
    default path is used. Lines are joined with newlines and the result has
    no trailing newline.
    """
    source_file = "/tmp/torchinductor_user/ab/test.py"
    header = {
        "inductor_output_code": {
            "filename": source_file,
            "file_path": source_file,
        },
        "frame_id": frame_id,
        "frame_compile_id": frame_compile_id,
        "attempt": attempt,
        "has_payload": "abc123",
    }
    if compiled_autograd_id is not None:
        header["compiled_autograd_id"] = compiled_autograd_id

    paths = (
        ["/tmp/torchinductor_user/ab/cabcdef1234.py"]
        if kernel_paths is None
        else kernel_paths
    )

    record = [_make_glog_line(header)]
    # Minimal output_code.py payload: one stanza per kernel path.
    for path in paths:
        record.extend(
            (
                f"\t# kernel path: {path}",
                "\t# Source Nodes: [add], Original ATen: [aten.add]",
                "\ttriton_poi_fused_add_0 = async_compile.triton('triton_poi_fused_add_0', '''",
                "\t# kernel code here",
                "\t''')",
            )
        )

    return "\n".join(record)
class TestParseTorchTraceLog(unittest.TestCase):
    """Tests for _parse_torch_trace_log."""

    def _write_log(self, content: str) -> str:
        """Write ``content`` to a temporary ``.log`` file and return its path.

        The file handle is closed by the ``with`` block even if the write
        fails, and the file is deleted when the test ends (pass or fail)
        through ``addCleanup``, so individual tests need no try/finally.
        """
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".log", delete=False, prefix="test_torch_trace_"
        ) as f:
            # Register removal before writing so a failed write is cleaned up too.
            self.addCleanup(os.unlink, f.name)
            f.write(content)
        return f.name

    def _parse_content(self, content: str):
        """Write ``content`` to a temp log and return the parsed mapping."""
        return _parse_torch_trace_log(self._write_log(content))

    def test_single_kernel(self):
        """Test parsing a log with a single inductor_output_code event."""
        kernel_path = "/tmp/torchinductor_user/ab/cabcdef1234.py"
        mapping = self._parse_content(
            _make_output_code_event(
                frame_id=0,
                frame_compile_id=0,
                kernel_paths=[kernel_path],
            )
        )
        self.assertEqual(len(mapping), 1)
        self.assertIn(kernel_path, mapping)
        info = mapping[kernel_path]
        self.assertEqual(info.frame_id, 0)
        self.assertEqual(info.frame_compile_id, 0)
        self.assertEqual(info.attempt, 0)
        self.assertIsNone(info.compiled_autograd_id)

    def test_multiple_kernels_in_one_event(self):
        """Test parsing an output_code event with multiple kernel paths."""
        kernel_paths = [
            "/tmp/torchinductor_user/ab/kernel1.py",
            "/tmp/torchinductor_user/cd/kernel2.py",
            "/tmp/torchinductor_user/ef/kernel3.py",
        ]
        mapping = self._parse_content(
            _make_output_code_event(
                frame_id=7,
                frame_compile_id=3,
                kernel_paths=kernel_paths,
            )
        )
        self.assertEqual(len(mapping), 3)
        for kp in kernel_paths:
            self.assertIn(kp, mapping)
            self.assertEqual(mapping[kp].frame_id, 7)
            self.assertEqual(mapping[kp].frame_compile_id, 3)

    def test_multiple_events_different_frames(self):
        """Test parsing multiple output_code events from different compilation frames."""
        event1 = _make_output_code_event(
            frame_id=0,
            frame_compile_id=0,
            kernel_paths=["/tmp/torchinductor_user/ab/kernel_a.py"],
        )
        event2 = _make_output_code_event(
            frame_id=1,
            frame_compile_id=0,
            kernel_paths=["/tmp/torchinductor_user/cd/kernel_b.py"],
        )
        # Non-output_code event in between
        other_event = _make_glog_line(
            {"dynamo_start": {"stack_index": 0}, "frame_id": 0}
        )

        mapping = self._parse_content(event1 + "\n" + other_event + "\n" + event2)
        self.assertEqual(len(mapping), 2)
        self.assertEqual(
            mapping["/tmp/torchinductor_user/ab/kernel_a.py"].frame_id, 0
        )
        self.assertEqual(
            mapping["/tmp/torchinductor_user/cd/kernel_b.py"].frame_id, 1
        )

    def test_with_compiled_autograd_id(self):
        """Test parsing events with compiled_autograd_id."""
        mapping = self._parse_content(
            _make_output_code_event(
                frame_id=2,
                frame_compile_id=1,
                compiled_autograd_id=5,
                kernel_paths=["/tmp/torchinductor_user/ab/kernel.py"],
            )
        )
        info = mapping["/tmp/torchinductor_user/ab/kernel.py"]
        self.assertEqual(info.compiled_autograd_id, 5)

    def test_no_output_code_events(self):
        """Test parsing a log with no inductor_output_code events."""
        mapping = self._parse_content(
            _make_glog_line({"dynamo_start": {"stack_index": 0}})
        )
        self.assertEqual(len(mapping), 0)

    def test_empty_file(self):
        """Test parsing an empty log file."""
        mapping = self._parse_content("")
        self.assertEqual(len(mapping), 0)

    def test_no_kernel_paths_in_payload(self):
        """Test an output_code event with no kernel path comments in payload."""
        metadata = {
            "inductor_output_code": {"filename": "test.py", "file_path": "test.py"},
            "frame_id": 0,
            "frame_compile_id": 0,
        }
        content = _make_glog_line(metadata) + "\n\t# some other comment\n\tcode = 42\n"
        mapping = self._parse_content(content)
        self.assertEqual(len(mapping), 0)

    def test_nonexistent_file(self):
        """Test parsing a file that doesn't exist."""
        mapping = _parse_torch_trace_log("/nonexistent/path/to/log.log")
        self.assertEqual(len(mapping), 0)

    def test_malformed_json(self):
        """Test that malformed JSON lines are skipped gracefully."""
        good_event = _make_output_code_event(
            frame_id=0,
            frame_compile_id=0,
            kernel_paths=["/tmp/torchinductor_user/ab/kernel.py"],
        )
        bad_line = "V0302 14:30:00.000000 999 path.py:1] {invalid json here"
        mapping = self._parse_content(bad_line + "\n" + good_event)
        # Should still parse the good event
        self.assertEqual(len(mapping), 1)
        self.assertIn("/tmp/torchinductor_user/ab/kernel.py", mapping)
class TestDiscoverTorchTraceFiles(unittest.TestCase):
    """Checks which files discover_torch_trace_files reports for a directory."""

    @staticmethod
    def _create_empty_file(directory, file_name):
        """Create an empty file called ``file_name`` inside ``directory``."""
        with open(os.path.join(directory, file_name), "w") as handle:
            handle.write("")

    def test_discover_ranked_files(self):
        """Rank-suffixed torch trace files are reported under their rank."""
        ranks = (0, 1, 2)
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create mock torch trace files, one per rank.
            for rank in ranks:
                self._create_empty_file(
                    tmpdir, f"dedicated_log_torch_trace_rank_{rank}_abc123.log"
                )

            discovered = discover_torch_trace_files(tmpdir)
            self.assertEqual(len(discovered), 3)
            for rank in ranks:
                self.assertIn(rank, discovered)

    def test_discover_no_rank_file(self):
        """A torch trace file without a rank suffix is reported under None."""
        with tempfile.TemporaryDirectory() as tmpdir:
            self._create_empty_file(tmpdir, "dedicated_log_torch_trace_abc123.log")

            discovered = discover_torch_trace_files(tmpdir)
            self.assertEqual(len(discovered), 1)
            self.assertIn(None, discovered)

    def test_ignores_non_torch_trace_files(self):
        """Only the torch trace file is reported; other files are ignored."""
        with tempfile.TemporaryDirectory() as tmpdir:
            for file_name in (
                # Tritonparse file (should be ignored)
                "dedicated_log_triton_trace_user_.ndjson",
                # Random file (should be ignored)
                "other_file.log",
                # Torch trace file (should be found)
                "dedicated_log_torch_trace_rank_0_abc.log",
            ):
                self._create_empty_file(tmpdir, file_name)

            discovered = discover_torch_trace_files(tmpdir)
            self.assertEqual(len(discovered), 1)
            self.assertIn(0, discovered)

    def test_empty_directory(self):
        """An empty directory yields an empty result."""
        with tempfile.TemporaryDirectory() as tmpdir:
            self.assertEqual(len(discover_torch_trace_files(tmpdir)), 0)

    def test_nonexistent_directory(self):
        """A directory that does not exist yields an empty result."""
        discovered = discover_torch_trace_files("/nonexistent/directory")
        self.assertEqual(len(discovered), 0)
def _build_kernel_compile_mapping(
    raw_log_dir: str,
    torch_trace_dir: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """
    Collect torch trace logs and parse them into a kernel compile mapping.

    The resulting kernel_source_path -> CompileInfo mapping is what lets a
    Triton kernel be attributed to its compilation frame when pt_info is
    missing (multi-process Triton JIT scenarios).

    Args:
        raw_log_dir: Directory containing tritonparse logs. It is always
            searched for torch trace logs.
        torch_trace_dir: Optional extra directory holding torch trace logs.
            When given, it is searched before raw_log_dir.

    Returns:
        The mapping produced by parse_torch_trace_logs, or None when no torch
        trace log was discovered or the parsed mapping is empty.
    """
    from .torch_trace_parser import discover_torch_trace_files, parse_torch_trace_logs

    # Explicit directory first (if any), then the raw log directory, where
    # torch trace logs may coexist with tritonparse logs.
    candidate_dirs = [torch_trace_dir, raw_log_dir] if torch_trace_dir else [raw_log_dir]

    # Dict keys give first-seen ordering and de-duplication in one structure.
    unique_paths: Dict[str, None] = {}
    for directory in candidate_dirs:
        if os.path.isdir(directory):
            files_by_rank = discover_torch_trace_files(directory)
            for rank_files in files_by_rank.values():
                for log_path in rank_files:
                    unique_paths.setdefault(log_path, None)

    all_log_paths: List[str] = list(unique_paths)
    if len(all_log_paths) == 0:
        return None

    mapping = parse_torch_trace_logs(all_log_paths)
    if not mapping:
        return None

    logger.info(
        f"Built kernel compile mapping with {len(mapping)} entries "
        f"from {len(all_log_paths)} torch trace log(s)"
    )
    return mapping
Returns: Tuple of (parsed log directory, file mapping) """ @@ -372,6 +430,10 @@ def parse_logs( ranks[Rank(Rank.NO_RANK)].append(path) if not ranks: raise RuntimeError(f"No eligible structured trace logs found in {raw_log_dir}") + + # Build kernel compile mapping from torch trace logs (if available) + kernel_compile_mapping = _build_kernel_compile_mapping(raw_log_dir, torch_trace_dir) + file_mapping = {"tritonparse_url_prefix": tritonparse_url_prefix} # Parse each eligible log for rank, files in ranks.items(): @@ -406,7 +468,12 @@ def parse_logs( relative_path = "" if rank.is_no_rank else rank.to_string("") output_dir = os.path.join(parsed_log_dir, relative_path) # Parse the file - parse_single_file(input_file, output_dir, split_inductor_compilations) + parse_single_file( + input_file, + output_dir, + split_inductor_compilations, + kernel_compile_mapping=kernel_compile_mapping, + ) # Collect generated files after parsing and gzip them immediately if os.path.exists(output_dir): generated_files = [] diff --git a/tritonparse/parse/torch_trace_parser.py b/tritonparse/parse/torch_trace_parser.py new file mode 100644 index 0000000..ebbec49 --- /dev/null +++ b/tritonparse/parse/torch_trace_parser.py @@ -0,0 +1,212 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +""" +Parser for inductor's torch trace logs. + +Extracts kernel_source_path -> CompileInfo mappings from inductor_output_code +events in torch trace log files. These mappings can be used to attribute Triton +kernels to their originating PyTorch compilation frame when pt_info is missing +(e.g., in multi-process Triton JIT compilation scenarios). 
# Pattern to extract kernel path from output_code payload
KERNEL_PATH_PATTERN = re.compile(r"^# kernel path: (.+)$", re.MULTILINE)


@dataclass
class CompileInfo:
    """Compilation frame info extracted from inductor's torch trace log."""

    # All fields are read from the JSON metadata of an inductor_output_code
    # record; the ids stay None when the metadata does not carry them.
    frame_id: Optional[int] = None
    frame_compile_id: Optional[int] = None
    attempt: int = 0
    compiled_autograd_id: Optional[int] = None


def _extract_json_from_glog_line(line: str) -> Optional[str]:
    """
    Extract JSON string from a glog-formatted line.

    Glog format: V{timestamp} {pid} {filepath}:{lineno}] {json_metadata}

    Args:
        line: One line of the log, without its trailing newline.

    Returns:
        Everything after the first "] " separator (possibly an empty string),
        or None if the line contains no such separator.
    """
    _, separator, json_part = line.partition("] ")
    return json_part if separator else None


def _parse_torch_trace_log(log_path: str) -> Dict[str, CompileInfo]:
    """
    Parse a single torch trace log file and extract kernel_source_path -> CompileInfo mappings.

    The torch trace log format is:
    - Each record starts with a glog prefix line containing JSON metadata
    - Subsequent lines starting with \\t are the payload (continuation of the record)

    For inductor_output_code events:
    - The JSON metadata contains frame_id, frame_compile_id, attempt, compiled_autograd_id
    - The payload contains the output_code.py content, which has '# kernel path: ...' comments

    Records of any other kind, lines without a glog prefix, and lines whose
    metadata is not a JSON object are ignored. An unreadable file is logged as
    a warning; whatever was mapped before the error is still returned.

    Args:
        log_path: Path to the torch trace log file.

    Returns:
        Dict mapping kernel_source_path (as written in the log) to CompileInfo.
        All kernel paths of one record share the same CompileInfo instance.
    """
    mapping: Dict[str, CompileInfo] = {}

    # State of the inductor_output_code record currently being read.
    # current_compile_info is None whenever we are outside such a record.
    current_compile_info: Optional[CompileInfo] = None
    current_payload_lines: List[str] = []

    def _flush_current_event() -> None:
        """Record the kernel paths of the accumulated payload, then reset the state."""
        nonlocal current_compile_info, current_payload_lines
        if current_compile_info is not None:
            payload_text = "\n".join(current_payload_lines)
            for kp in KERNEL_PATH_PATTERN.findall(payload_text):
                kp = kp.strip()
                if kp:
                    mapping[kp] = current_compile_info
                    logger.debug(
                        f"Mapped kernel path {kp} -> frame_id={current_compile_info.frame_id}, "
                        f"frame_compile_id={current_compile_info.frame_compile_id}"
                    )
        current_compile_info = None
        current_payload_lines = []

    try:
        # Decode as UTF-8 explicitly: the locale default would make the
        # extracted kernel paths platform-dependent, and a mangled path can
        # never match the path recorded on the Triton side.
        with open(log_path, "r", encoding="utf-8", errors="replace") as f:
            for line in f:
                line = line.rstrip("\n")

                # Continuation line (tab-indented payload)
                if line.startswith("\t"):
                    if current_compile_info is not None:
                        # Strip the leading tab
                        current_payload_lines.append(line[1:])
                    continue

                # New record — flush the previous event first
                _flush_current_event()

                # Try to parse this as a new glog record with JSON metadata
                json_str = _extract_json_from_glog_line(line)
                if not json_str:
                    continue

                try:
                    metadata = json.loads(json_str)
                except ValueError:  # json.JSONDecodeError is a ValueError
                    continue

                # Only inductor_output_code records carry kernel paths
                if not isinstance(metadata, dict):
                    continue
                if "inductor_output_code" not in metadata:
                    continue

                # An explicit JSON null must not leak into CompileInfo.attempt,
                # which is an int and is later formatted into file names.
                attempt = metadata.get("attempt")
                current_compile_info = CompileInfo(
                    frame_id=metadata.get("frame_id"),
                    frame_compile_id=metadata.get("frame_compile_id"),
                    attempt=attempt if attempt is not None else 0,
                    compiled_autograd_id=metadata.get("compiled_autograd_id"),
                )

            # Flush the last event
            _flush_current_event()

    except OSError as e:
        logger.warning(f"Failed to read torch trace log {log_path}: {e}")

    return mapping


def parse_torch_trace_logs(
    log_paths: List[str],
) -> Dict[str, CompileInfo]:
    """
    Parse multiple torch trace log files and merge their mappings.

    Files are processed in the given order; when the same kernel path appears
    in several files, the entry from the later file wins.

    Args:
        log_paths: List of paths to torch trace log files.

    Returns:
        Merged dict mapping kernel_source_path to CompileInfo.
    """
    merged: Dict[str, CompileInfo] = {}
    for path in log_paths:
        logger.info(f"Parsing torch trace log: {path}")
        file_mapping = _parse_torch_trace_log(path)
        logger.info(f"Extracted {len(file_mapping)} kernel path mappings from {path}")
        merged.update(file_mapping)
    return merged


# Prefix used by torch's structured trace logging
TORCH_TRACE_PREFIX = "dedicated_log_torch_trace_"


def discover_torch_trace_files(
    search_dir: str,
) -> Dict[Optional[int], List[str]]:
    """
    Discover torch trace log files in a directory, grouped by rank.

    A file qualifies when its name contains TORCH_TRACE_PREFIX and ends with
    ".log". The rank is taken from a "rank_<N>_" fragment of the file name.
    Only the directory itself is scanned, not its subdirectories. A directory
    that cannot be listed is logged as a warning and yields an empty result.

    Args:
        search_dir: Directory to search for torch trace log files.

    Returns:
        Dict mapping rank (int or None for no-rank files) to list of file paths,
        each list sorted by file name.
    """
    rank_pattern = re.compile(r"rank_(\d+)_")
    result: Dict[Optional[int], List[str]] = {}

    try:
        # os.listdir returns entries in arbitrary order. Sort them so the order
        # in which logs are later merged (last one wins) is reproducible.
        for item in sorted(os.listdir(search_dir)):
            if TORCH_TRACE_PREFIX not in item:
                continue
            if not item.endswith(".log"):
                continue
            full_path = os.path.join(search_dir, item)
            if not os.path.isfile(full_path):
                continue

            rank_match = rank_pattern.search(item)
            rank = int(rank_match.group(1)) if rank_match else None
            result.setdefault(rank, []).append(full_path)
    except OSError as e:
        logger.warning(
            f"Failed to scan directory {search_dir} for torch trace logs: {e}"
        )

    if result:
        total_files = sum(len(v) for v in result.values())
        logger.info(
            f"Discovered {total_files} torch trace log file(s) across "
            f"{len(result)} rank(s) in {search_dir}"
        )

    return result
def _resolve_compile_info(
    event: Dict[str, Any],
    kernel_compile_mapping: Dict[str, Any],
) -> Optional[Any]:
    """
    Resolve CompileInfo for a compilation event using kernel_compile_mapping.

    Attempts to find the kernel's source path from the event and look it up
    in the mapping to recover frame_id/compile_id when pt_info is missing.

    Resolution order:
    1. payload.python_source.file_path
    2. The first stack frame whose filename is a key of the mapping

    Missing, null, or wrongly-typed "payload", "python_source" and "stack"
    entries are treated as absent rather than raising.

    Args:
        event: A compilation event dict.
        kernel_compile_mapping: Mapping from kernel_source_path to CompileInfo.

    Returns:
        CompileInfo if found, None otherwise.
    """
    # Try python_source.file_path first. `.get(key, {})` is not enough here:
    # a key that is present with a JSON null would still hand back None.
    payload = event.get("payload")
    python_source = payload.get("python_source") if isinstance(payload, dict) else None
    kernel_path = (
        python_source.get("file_path") if isinstance(python_source, dict) else None
    )
    if isinstance(kernel_path, str) and kernel_path in kernel_compile_mapping:
        return kernel_compile_mapping[kernel_path]

    # Fallback: scan the stack trace. Membership in the mapping is the only
    # criterion, same as for python_source above, so kernels generated under a
    # custom cache directory (no "torchinductor" in the path) resolve too.
    stack = event.get("stack")
    if isinstance(stack, (list, tuple)):
        for frame in stack:
            if not isinstance(frame, dict):
                continue
            filename = frame.get("filename")
            if isinstance(filename, str) and filename in kernel_compile_mapping:
                return kernel_compile_mapping[filename]

    return None


def _determine_output_fname(
    pt_info: Dict[str, Any],
    file_name_without_extension: str,
    split_inductor_compilations: bool,
    event: Optional[Dict[str, Any]] = None,
    kernel_compile_mapping: Optional[Dict[str, Any]] = None,
) -> str:
    """
    Determine the output filename for a compilation event.

    When pt_info contains frame_id/frame_compile_id, uses those directly.
    When pt_info has neither but kernel_compile_mapping is available,
    attempts to resolve via python_source or stack trace.

    Args:
        pt_info: The pt_info dict from the compilation payload. None is treated
            like an empty dict.
        file_name_without_extension: Base name for the default mapped file.
        split_inductor_compilations: Whether splitting is enabled.
        event: The full compilation event (used for mapping resolution).
        kernel_compile_mapping: Optional mapping from kernel paths to CompileInfo.

    Returns:
        Output filename string (without directory): either
        "f{frame_id}_fc{frame_compile_id}_a{attempt}_cai{compiled_autograd_id}.ndjson"
        or the default "{file_name_without_extension}_mapped.ndjson".
    """
    default_fname = f"{file_name_without_extension}_mapped.ndjson"
    if not split_inductor_compilations:
        return default_fname

    # Callers pass payload.get("pt_info", {}), which is None when the key is
    # present but null in the trace.
    pt_info = pt_info or {}
    frame_id = pt_info.get("frame_id")
    frame_compile_id = pt_info.get("frame_compile_id")
    attempt_id = pt_info.get("attempt_id", 0)
    cai = pt_info.get("compiled_autograd_id", "-")

    # Try to resolve via mapping when pt_info is missing
    if frame_id is None and frame_compile_id is None:
        if event is not None and kernel_compile_mapping:
            resolved = _resolve_compile_info(event, kernel_compile_mapping)
            if resolved is not None:
                frame_id = resolved.frame_id
                frame_compile_id = resolved.frame_compile_id
                # Keep the same defaults as the pt_info path when the resolved
                # info leaves these unset.
                attempt_id = resolved.attempt if resolved.attempt is not None else 0
                cai = (
                    resolved.compiled_autograd_id
                    if resolved.compiled_autograd_id is not None
                    else "-"
                )

    if frame_id is None and frame_compile_id is None:
        return default_fname
    return f"f{frame_id}_fc{frame_compile_id}_a{attempt_id}_cai{cai}.ndjson"
""" # ===================================================== # Pass 1: Pre-scan to identify kernels needing fake compilations @@ -408,8 +503,14 @@ def parse_single_file( for fake_comp in fake_compilations: kernel_hash = fake_comp["payload"]["metadata"]["hash"] - # Set output_file (use default file name since pt_info is not available) - fname = f"{file_name_without_extension}_mapped.ndjson" + # Determine output file — try mapping resolution for fake compilations too + fname = _determine_output_fname( + pt_info={}, + file_name_without_extension=file_name_without_extension, + split_inductor_compilations=split_inductor_compilations, + event=fake_comp, + kernel_compile_mapping=kernel_compile_mapping, + ) output_file = os.path.join(output_dir, fname) # Store in kernels_by_hash (without occurrence_id for now) @@ -463,18 +564,13 @@ def parse_single_file( # Split inductor compilations into separate files # This rule follows tlparse's behavior. - if split_inductor_compilations: - pt_info = payload.get("pt_info", {}) - frame_id = pt_info.get("frame_id") - frame_compile_id = pt_info.get("frame_compile_id") - attempt_id = pt_info.get("attempt_id", 0) - cai = pt_info.get("compiled_autograd_id", "-") - if frame_id is not None or frame_compile_id is not None: - fname = f"f{frame_id}_fc{frame_compile_id}_a{attempt_id}_cai{cai}.ndjson" - else: - fname = f"{file_name_without_extension}_mapped.ndjson" - else: - fname = f"{file_name_without_extension}_mapped.ndjson" + fname = _determine_output_fname( + pt_info=payload.get("pt_info", {}), + file_name_without_extension=file_name_without_extension, + split_inductor_compilations=split_inductor_compilations, + event=parsed_json, + kernel_compile_mapping=kernel_compile_mapping, + ) output_file = os.path.join(output_dir, fname) # The full processing is deferred until the final write. 
diff --git a/tritonparse/parse/utils.py b/tritonparse/parse/utils.py index e4d95fa..7119dae 100644 --- a/tritonparse/parse/utils.py +++ b/tritonparse/parse/utils.py @@ -76,6 +76,16 @@ def _add_parse_args(parser: argparse.ArgumentParser) -> None: action="store_true", ) parser.add_argument("-v", "--verbose", help="Verbose logging", action="store_true") + parser.add_argument( + "--torch-trace-dir", + type=str, + default=None, + help=( + "Path to directory containing inductor torch trace logs. " + "Used to recover kernel compilation attribution in multi-process scenarios. " + "If not specified, auto-discovers torch trace files alongside tritonparse logs." + ), + ) if is_fbcode(): from tritonparse.fb.utils import append_parser @@ -91,6 +101,7 @@ def oss_run( verbose: bool = False, split_inductor_compilations: bool = True, skip_logger: bool = True, + torch_trace_dir: Optional[str] = None, ): """ Main function for tritonparse. It is for OSS only. @@ -103,6 +114,7 @@ def oss_run( all_ranks: Analyze all ranks verbose: Verbose logging skip_logger: Unused in OSS, kept for API compatibility. + torch_trace_dir: Path to directory containing inductor torch trace logs. """ source = Source(source, verbose) rank_config = RankConfig.from_cli_args(rank, all_ranks, source.type) @@ -137,6 +149,7 @@ def oss_run( rank_config, verbose, split_inductor_compilations=split_inductor_compilations, + torch_trace_dir=torch_trace_dir, ) else: parsed_log_dir = source.value @@ -161,6 +174,7 @@ def unified_parse( verbose: bool = False, split_inductor_compilations: bool = True, skip_logger: bool = False, + torch_trace_dir: Optional[str] = None, **kwargs, ): """ @@ -174,6 +188,7 @@ def unified_parse( all_ranks: Whether to analyze all ranks verbose: Whether to enable verbose logging skip_logger: Whether to skip usage logging (default: False). + torch_trace_dir: Path to directory containing inductor torch trace logs. 
""" # Log usage for API invocations if not skip_logger and is_fbcode(): @@ -196,6 +211,7 @@ def unified_parse( verbose=verbose, split_inductor_compilations=split_inductor_compilations, skip_logger=skip_logger, + torch_trace_dir=torch_trace_dir, **kwargs, ) return output