Skip to content

Commit 5d05cb9

Browse files
authored
Exclude L2 cache clear time from timing measurement (#527)
1 parent 0bb2da4 commit 5d05cb9

File tree

1 file changed

+25
-3
lines changed
  • tritonbench/components/do_bench

1 file changed

+25
-3
lines changed

tritonbench/components/do_bench/run.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,18 +205,40 @@ def _do_bench_cudagraph_with_cache_clear(
205205
fn()
206206
torch.cuda.synchronize()
207207

208-
ret = []
208+
cache_clear_graph = torch.cuda.CUDAGraph()
209+
with torch.cuda.graph(cache_clear_graph):
210+
for _ in range(n_repeat):
211+
cache.zero_()
212+
torch.cuda.synchronize()
213+
209214
n_retries = 10
215+
cache_clear_times = []
216+
total_times = []
210217
for _ in range(n_retries):
218+
cache_clear_start_event = torch.cuda.Event(enable_timing=True)
219+
cache_clear_end_event = torch.cuda.Event(enable_timing=True)
220+
cache_clear_start_event.record()
221+
cache_clear_graph.replay()
222+
cache_clear_end_event.record()
223+
torch.cuda.synchronize()
224+
cache_clear_times.append(
225+
cache_clear_start_event.elapsed_time(cache_clear_end_event) / n_repeat
226+
)
227+
211228
start_event = torch.cuda.Event(enable_timing=True)
212229
end_event = torch.cuda.Event(enable_timing=True)
213230
start_event.record()
214231
g.replay()
215232
end_event.record()
216233
torch.cuda.synchronize()
217-
ret.append(start_event.elapsed_time(end_event) / n_repeat)
234+
total_times.append(start_event.elapsed_time(end_event) / n_repeat)
218235

219-
times = torch.tensor(ret, dtype=torch.float)
236+
all_kernel_times = []
237+
for total_time, cache_clear_time in zip(total_times, cache_clear_times):
238+
kernel_time = total_time - cache_clear_time
239+
all_kernel_times.append(kernel_time)
240+
241+
times = torch.tensor(all_kernel_times, dtype=torch.float)
220242
return _summarize_statistics(times, quantiles, return_mode)
221243

222244

0 commit comments

Comments (0)