
Commit c6c6962

Profiler mode: fix missing CUDA events (#459)
1 parent 0d8be61 commit c6c6962

File tree

1 file changed: +16 -13 lines changed

  • tritonbench/components/do_bench/run.py


tritonbench/components/do_bench/run.py

Lines changed: 16 additions & 13 deletions
@@ -236,17 +236,7 @@ def run_iteration():
         "with_stack": False,
     }
 
-    for _ in range(n_profiler_runs):
-        # Profile execution
-        with torch.profiler.profile(**profiler_config) as prof:
-            if use_cudagraph:
-                g.replay()
-            else:
-                # Execute multiple iterations for regular mode
-                for _ in range(iterations_per_profiler_run):
-                    run_iteration()
-        torch.cuda.synchronize()
-
+    def _trace_handler(prof: torch.profiler.profile) -> None:
         # Collect all kernel execution intervals
         kernel_intervals = []
 
@@ -299,10 +289,23 @@ def run_iteration():
         )
 
         # Convert to milliseconds and normalize by iterations
-        total_kernel_time_ms = (
+        kernel_time_per_iteration_ms = (
             total_kernel_time_us / 1000.0
         ) / iterations_per_profiler_run
-        all_kernel_times.append(total_kernel_time_ms)
+        all_kernel_times.append(kernel_time_per_iteration_ms)
+
+    for _ in range(n_profiler_runs):
+        # Profile execution
+        with torch.profiler.profile(
+            **profiler_config, on_trace_ready=_trace_handler
+        ) as prof:
+            if use_cudagraph:
+                g.replay()
+            else:
+                # Execute multiple iterations for regular mode
+                for _ in range(iterations_per_profiler_run):
+                    run_iteration()
+        torch.cuda.synchronize()
 
     times = torch.tensor(all_kernel_times, dtype=torch.float)
     return _summarize_statistics(times, quantiles=None, return_mode=return_mode)
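
The change routes kernel-event collection through torch.profiler's on_trace_ready callback, which fires only after the profiler has finalized the trace, instead of reading the profile inside the benchmarking loop; per the commit title, the latter could miss CUDA events. Below is a minimal, self-contained sketch of that pattern, not the tritonbench implementation: trace_handler, bench_fn, and kernel_times_ms are illustrative names, and the event attributes used (self_device_time_total, DeviceType.CUDA) follow recent PyTorch releases and may be named differently in older versions.

# Minimal sketch of the on_trace_ready pattern adopted by this fix
# (illustrative, not the tritonbench code).
import torch
from torch.profiler import ProfilerActivity, profile

kernel_times_ms = []  # one aggregated device-time entry per profiled run


def trace_handler(prof: torch.profiler.profile) -> None:
    # Runs once the trace has been assembled, so device-side (CUDA) events
    # are available. Sum their durations (reported in microseconds).
    # Note: attribute and enum locations may differ across PyTorch versions.
    cuda_time_us = sum(
        evt.self_device_time_total
        for evt in prof.events()
        if evt.device_type == torch.autograd.DeviceType.CUDA
    )
    kernel_times_ms.append(cuda_time_us / 1000.0)


def bench_fn() -> None:
    # Placeholder workload standing in for run_iteration() or CUDA graph replay.
    x = torch.randn(1024, 1024, device="cuda")
    (x @ x).sum().item()


if torch.cuda.is_available():
    for _ in range(3):  # analogous to n_profiler_runs
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            on_trace_ready=trace_handler,
        ):
            bench_fn()
        torch.cuda.synchronize()
    print(kernel_times_ms)

Aggregating inside the callback rather than after the with block mirrors the design choice in the diff: by the time on_trace_ready runs, the trace, including device activity, has been parsed, so the handler sees a complete event list.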
