|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +"""Tests for kernel attribution via compile mapping in trace_processor.""" |
| 3 | + |
| 4 | +import json |
| 5 | +import os |
| 6 | +import tempfile |
| 7 | +import unittest |
| 8 | + |
| 9 | +from tritonparse.parse.torch_trace_parser import CompileInfo |
| 10 | +from tritonparse.parse.trace_processor import ( |
| 11 | + _determine_output_fname, |
| 12 | + _resolve_compile_info, |
| 13 | + parse_single_file, |
| 14 | +) |
| 15 | + |
| 16 | + |
class TestResolveCompileInfo(unittest.TestCase):
    """Unit tests for ``_resolve_compile_info``.

    Every test builds a minimal trace event and resolves it against the
    two-entry kernel-path -> ``CompileInfo`` mapping from ``_make_mapping``.
    """

    # Kernel source paths used as keys of the mapping under test.
    KERNEL1_PATH = "/tmp/torchinductor_user/ab/kernel1.py"
    KERNEL2_PATH = "/tmp/torchinductor_user/cd/kernel2.py"

    def _make_mapping(self):
        """Return a fresh mapping so tests never share mutable state."""
        return {
            self.KERNEL1_PATH: CompileInfo(frame_id=0, frame_compile_id=0),
            self.KERNEL2_PATH: CompileInfo(
                frame_id=1, frame_compile_id=0, attempt=1
            ),
        }

    def test_resolve_via_python_source(self):
        """Test resolution via python_source.file_path."""
        event = {
            "payload": {"python_source": {"file_path": self.KERNEL1_PATH}},
            "stack": [],
        }
        result = _resolve_compile_info(event, self._make_mapping())
        self.assertIsNotNone(result)
        self.assertEqual(result.frame_id, 0)
        self.assertEqual(result.frame_compile_id, 0)

    def test_resolve_via_stack_trace(self):
        """Test resolution via stack trace when python_source is missing."""
        event = {
            "payload": {},
            "stack": [
                {"filename": "/user/code.py", "line": 10, "name": "main"},
                {
                    "filename": self.KERNEL2_PATH,
                    "line": 1,
                    "name": "kernel",
                },
                {"filename": "triton/jit.py", "line": 50, "name": "run"},
            ],
        }
        result = _resolve_compile_info(event, self._make_mapping())
        self.assertIsNotNone(result)
        self.assertEqual(result.frame_id, 1)
        self.assertEqual(result.frame_compile_id, 0)
        self.assertEqual(result.attempt, 1)

    def test_no_match(self):
        """Test that None is returned when no match is found."""
        event = {
            "payload": {"python_source": {"file_path": "/tmp/unknown/path.py"}},
            "stack": [{"filename": "/user/code.py", "line": 10, "name": "main"}],
        }
        result = _resolve_compile_info(event, self._make_mapping())
        self.assertIsNone(result)

    def test_empty_event(self):
        """Test with minimal event data."""
        event = {}
        result = _resolve_compile_info(event, self._make_mapping())
        self.assertIsNone(result)

    def test_python_source_takes_priority(self):
        """Test that python_source.file_path is preferred over stack trace."""
        event = {
            "payload": {"python_source": {"file_path": self.KERNEL1_PATH}},
            "stack": [
                {
                    "filename": self.KERNEL2_PATH,
                    "line": 1,
                    "name": "kernel",
                }
            ],
        }
        result = _resolve_compile_info(event, self._make_mapping())
        # Guard first so a failed resolution is reported as an assertion
        # failure instead of an AttributeError on None.
        self.assertIsNotNone(result)
        # Should use python_source path (kernel1 -> frame_id=0),
        # not the stack (kernel2 -> frame_id=1).
        self.assertEqual(result.frame_id, 0)
| 95 | + |
| 96 | + |
class TestDetermineOutputFname(unittest.TestCase):
    """Checks the output file name chosen by ``_determine_output_fname``."""

    def test_with_pt_info(self):
        """pt_info carrying frame/compile ids yields a frame-specific name."""
        pt_info = {"frame_id": 0, "frame_compile_id": 1, "attempt_id": 0}
        output_name = _determine_output_fname(
            pt_info=pt_info,
            file_name_without_extension="trace",
            split_inductor_compilations=True,
        )
        self.assertEqual(output_name, "f0_fc1_a0_cai-.ndjson")

    def test_without_pt_info_no_mapping(self):
        """Empty pt_info and no mapping falls back to the ``_mapped`` file."""
        output_name = _determine_output_fname(
            pt_info={},
            file_name_without_extension="trace",
            split_inductor_compilations=True,
        )
        self.assertEqual(output_name, "trace_mapped.ndjson")

    def test_without_pt_info_with_mapping(self):
        """Empty pt_info is resolved through the kernel compile mapping."""
        kernel_file = "/tmp/torchinductor_user/ab/kernel.py"
        compile_info = CompileInfo(
            frame_id=3, frame_compile_id=2, attempt=1, compiled_autograd_id=5
        )
        trace_event = {
            "payload": {"python_source": {"file_path": kernel_file}},
            "stack": [],
        }
        output_name = _determine_output_fname(
            pt_info={},
            file_name_without_extension="trace",
            split_inductor_compilations=True,
            event=trace_event,
            kernel_compile_mapping={kernel_file: compile_info},
        )
        self.assertEqual(output_name, "f3_fc2_a1_cai5.ndjson")

    def test_split_disabled(self):
        """With splitting off, the ``_mapped`` name is used despite pt_info."""
        pt_info = {"frame_id": 0, "frame_compile_id": 0}
        output_name = _determine_output_fname(
            pt_info=pt_info,
            file_name_without_extension="trace",
            split_inductor_compilations=False,
        )
        self.assertEqual(output_name, "trace_mapped.ndjson")

    def test_compiled_autograd_id_none(self):
        """An unset compiled_autograd_id is rendered as '-' in the name."""
        pt_info = {"frame_id": 0, "frame_compile_id": 0}
        output_name = _determine_output_fname(
            pt_info=pt_info,
            file_name_without_extension="trace",
            split_inductor_compilations=True,
        )
        self.assertEqual(output_name, "f0_fc0_a0_cai-.ndjson")
| 157 | + |
| 158 | + |
class TestParseSingleFileWithMapping(unittest.TestCase):
    """Integration tests for parse_single_file with kernel_compile_mapping.

    Each test describes a small trace as a list of event dicts, runs it
    through ``parse_single_file`` in a throwaway directory, and inspects the
    names of the files that were produced.
    """

    @staticmethod
    def _compilation_event(kernel_hash, name, **extra_payload):
        """Build a compilation event; *extra_payload* adds payload keys.

        Pass e.g. ``python_source=...`` or ``pt_info=...`` to extend the
        otherwise minimal payload.
        """
        payload = {
            "metadata": {"hash": kernel_hash, "name": name},
            "file_content": {},
            "file_path": {},
        }
        payload.update(extra_payload)
        return {
            "event_type": "compilation",
            "pid": 1000,
            "stack": [],
            "payload": payload,
        }

    @staticmethod
    def _launch_event(kernel_hash, name):
        """Build a launch event that refers to a compilation by its hash."""
        return {
            "event_type": "launch",
            "name": name,
            "pid": 1000,
            "stack": [],
            "compilation_metadata": {"hash": kernel_hash},
        }

    def _parse_events(self, events, mapping=None):
        """Write *events* as NDJSON, parse them, and return the output names.

        Args:
            events: Iterable of JSON-serializable event dicts, one per line.
            mapping: Optional kernel-path -> CompileInfo mapping. When None,
                ``parse_single_file`` is called without the
                ``kernel_compile_mapping`` argument at all.

        Returns:
            The ``os.listdir`` result for the output directory.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            input_file = os.path.join(temp_dir, "test_trace.ndjson")
            with open(input_file, "w", encoding="utf-8") as f:
                f.writelines(json.dumps(event) + "\n" for event in events)

            output_dir = os.path.join(temp_dir, "output")
            os.makedirs(output_dir)

            if mapping is None:
                parse_single_file(input_file, output_dir)
            else:
                parse_single_file(
                    input_file, output_dir, kernel_compile_mapping=mapping
                )

            # Must be listed before the temporary directory is removed.
            return os.listdir(output_dir)

    def test_mapping_redirects_compilation_to_frame_file(self):
        """Test that a compilation without pt_info is redirected when mapping is provided."""
        kernel_path = "/tmp/torchinductor_user/ab/cabcdef1234.py"
        mapping = {
            kernel_path: CompileInfo(frame_id=2, frame_compile_id=1),
        }
        events = [
            # No pt_info: this is the multi-process scenario.
            self._compilation_event(
                "kernel_hash_1",
                "kernel_1",
                python_source={"file_path": kernel_path},
            ),
            self._launch_event("kernel_hash_1", "kernel_1"),
        ]

        output_files = self._parse_events(events, mapping)

        # Should produce a frame-specific file, not _mapped.
        frame_files = [f for f in output_files if f.startswith("f")]
        mapped_files = [f for f in output_files if "mapped" in f]
        self.assertEqual(len(frame_files), 1)
        self.assertEqual(len(mapped_files), 0)
        self.assertEqual(frame_files[0], "f2_fc1_a0_cai-.ndjson")

    def test_no_mapping_falls_back_to_mapped(self):
        """Test that without mapping, compilations without pt_info go to _mapped."""
        events = [
            # No pt_info, no python_source.
            self._compilation_event("kernel_hash_1", "kernel_1"),
            self._launch_event("kernel_hash_1", "kernel_1"),
        ]

        output_files = self._parse_events(events)

        mapped_files = [f for f in output_files if "mapped" in f]
        self.assertGreater(len(mapped_files), 0)

    def test_mixed_with_and_without_pt_info(self):
        """Test a mix of events: some with pt_info, some resolved via mapping."""
        kernel_path_a = "/tmp/torchinductor_user/ab/kernel_a.py"
        mapping = {
            kernel_path_a: CompileInfo(frame_id=0, frame_compile_id=0),
        }
        events = [
            # Compilation WITH pt_info (should be split normally).
            self._compilation_event(
                "hash_with_pt",
                "kernel_with_pt",
                pt_info={"frame_id": 1, "frame_compile_id": 0},
            ),
            self._launch_event("hash_with_pt", "kernel_with_pt"),
            # Compilation WITHOUT pt_info (should be resolved via mapping).
            self._compilation_event(
                "hash_without_pt",
                "kernel_without_pt",
                python_source={"file_path": kernel_path_a},
            ),
            self._launch_event("hash_without_pt", "kernel_without_pt"),
        ]

        output_files = self._parse_events(events, mapping)

        # Should have two frame files: f0_fc0 and f1_fc0.
        frame_files = [f for f in output_files if f.startswith("f")]
        self.assertEqual(len(frame_files), 2)
        self.assertIn("f0_fc0_a0_cai-.ndjson", frame_files)
        self.assertIn("f1_fc0_a0_cai-.ndjson", frame_files)

    def test_fake_compilation_with_mapping(self):
        """Test that fake compilations can also be attributed via stack trace mapping."""
        kernel_path = "/tmp/torchinductor_user/ab/kernel_fake.py"
        mapping = {
            kernel_path: CompileInfo(frame_id=5, frame_compile_id=0),
        }
        events = [
            # Only a launch event (will trigger fake compilation); the kernel
            # path appears solely in the stack trace.
            {
                "event_type": "launch",
                "name": "fake_kernel",
                "pid": 1000,
                "stack": [
                    {"filename": "/user/code.py", "line": 10, "name": "main"},
                    {
                        "filename": kernel_path,
                        "line": 1,
                        "name": "kernel_fn",
                    },
                ],
                "compilation_metadata": {
                    "hash": "fake_hash",
                    "name": "fake_kernel",
                    "num_warps": 4,
                },
            },
        ]

        output_files = self._parse_events(events, mapping)

        frame_files = [f for f in output_files if f.startswith("f")]
        self.assertEqual(len(frame_files), 1)
        self.assertEqual(frame_files[0], "f5_fc0_a0_cai-.ndjson")
| 382 | + |
| 383 | + |
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments