Skip to content

Commit 839af17

Browse files
FindHao authored and meta-codesync[bot] committed
Error out when reproducing CUDA graph capture kernels (#359)
Summary: Pull Request resolved: #359. Fixes #276.

When a kernel is launched during CUDA graph capture, argument extraction is skipped (D86722827) and `extracted_args` contains only a `_note` string instead of per-argument dicts. The reproducer's `_create_arg_from_info()` then crashes with `AttributeError: 'str' object has no attribute 'get'`.

This change detects the `_note` sentinel in `build_context_bundle()` and raises a clear `RuntimeError` early, before any downstream code tries to process the incomplete argument data.

Reviewed By: xuzhao9
Differential Revision: D97514433
fbshipit-source-id: c4e61d50884e0e6e451c581f1d602970247e4849
1 parent e3bae57 commit 839af17

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

tests/cpu/test_placeholder_replacer.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,5 +203,91 @@ def test_scratch_size_none_when_absent(self):
203203
self.assertIsNone(bundle.compile["global_scratch_size"])
204204

205205

206+
class TestBuildContextBundleCudaGraphCapture(unittest.TestCase):
    """Tests that build_context_bundle raises RuntimeError for CUDA graph capture launches.

    When a kernel is launched during CUDA graph capture, argument extraction
    is skipped (see D86722827) and extracted_args contains only a _note string
    instead of per-argument dicts. build_context_bundle should detect this and
    raise a clear RuntimeError.
    """

    @staticmethod
    def _make_events(extracted_args):
        """Build a minimal launch + compilation event pair.

        Both events carry the hash "abc123" and the compilation event names
        the kernel "my_kernel"; only the launch event's extracted_args vary
        between the tests in this class.

        Args:
            extracted_args: Value stored under the launch event's
                "extracted_args" key.

        Returns:
            A two-element list: [launch_event, comp_event].
        """
        launch_event = {
            "event_type": "launch",
            "grid": [1, 1, 1],
            "extracted_args": extracted_args,
            "compilation_metadata": {
                "hash": "abc123",
                "num_warps": 4,
                "num_stages": 2,
            },
        }
        comp_event = {
            "event_type": "compilation",
            "payload": {
                "metadata": {"hash": "abc123", "name": "my_kernel"},
                "python_source": {
                    "file_path": "/tmp/kernel.py",
                    "code": "@triton.jit\ndef my_kernel(): pass",
                },
            },
            "stack": [],
        }
        return [launch_event, comp_event]

    def _make_events_with_note(self, note_message):
        """Create launch + compilation events where extracted_args has a _note sentinel."""
        return self._make_events({"_note": note_message})

    def test_note_sentinel_raises_runtime_error(self):
        """Test that _note in extracted_args raises RuntimeError."""
        # Imported inside the test, as in the original, so the import is
        # only attempted when this test actually runs.
        from tritonparse.reproducer.ingestion.ndjson import build_context_bundle

        note = "Argument extraction skipped during CUDA graph capture"
        events = self._make_events_with_note(note)

        with self.assertRaises(RuntimeError) as cm:
            build_context_bundle(events, line_index=0)

        # The message must name the failure, the kernel, and the original note.
        error_msg = str(cm.exception)
        self.assertIn("Cannot generate reproducer", error_msg)
        self.assertIn("my_kernel", error_msg)
        self.assertIn(note, error_msg)

    def test_no_note_sentinel_succeeds(self):
        """Test that normal extracted_args without _note works fine."""
        from tritonparse.reproducer.ingestion.ndjson import build_context_bundle

        events = self._make_events(
            {
                "x": {"type": "tensor", "shape": [4, 4], "dtype": "float32"},
                "BLOCK_SIZE": {"type": "int", "value": 128},
            }
        )

        # Should not raise
        bundle = build_context_bundle(events, line_index=0)
        self.assertEqual(bundle.kernel_info.function_name, "my_kernel")
206292
if __name__ == "__main__":
207293
unittest.main()

tritonparse/reproducer/ingestion/ndjson.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,19 @@ def build_context_bundle(
212212
kernel_info = get_kernel_info(comp_event)
213213
grid = launch_event.get("grid")
214214
extracted_args = launch_event.get("extracted_args", {})
215+
216+
# Check if this launch event was captured during CUDA graph capture,
217+
# which means argument extraction was skipped and we cannot generate
218+
# a reproducer. See D86722827 for context on why extraction is skipped.
219+
if "_note" in extracted_args:
220+
raise RuntimeError(
221+
f"Cannot generate reproducer for kernel "
222+
f"'{kernel_info.function_name}' at line {line_index}: "
223+
f"{extracted_args['_note']}. "
224+
f"Kernel launches during CUDA graph capture do not have "
225+
f"extracted argument data needed for reproducer generation."
226+
)
227+
215228
comp_meta = launch_event.get("compilation_metadata", {})
216229

217230
# Compile metadata subset we care about.

0 commit comments

Comments
 (0)