Commit 379a315

Add L2 cache clearing to do_bench_cudagraph, for more realistic timing (#519)
1 parent fe646da commit 379a315
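
For context on why clearing L2 matters: back-to-back timed calls reuse data that the previous call left in the L2 cache, so measured latencies come out optimistically low for workloads whose inputs are not already cache-resident. A minimal sketch of the flush idea, not tritonbench code (the buffer size below is an assumption; the commit instead sizes the buffer per device via triton.runtime.driver.active.get_empty_cache_for_benchmark()):

import torch

# Hypothetical flush buffer: zeroing a region larger than L2 evicts the
# benchmarked kernel's data, so the next timed call starts from a cold cache.
# 256 MB is an assumed size, comfortably above L2 on current NVIDIA GPUs.
l2_flush = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda")

def time_once_cold(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    l2_flush.zero_()  # evict fn's working set from L2 before timing
    start.record()
    fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end)  # milliseconds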

File tree

1 file changed: +55 −1
  • tritonbench/components/do_bench


tritonbench/components/do_bench/run.py

Lines changed: 55 additions & 1 deletion
@@ -166,6 +166,60 @@ def _do_bench_inductor(fn, warmup, rep, return_mode="all", grad_to_none=None):
     return _summarize_statistics(times, quantiles=None, return_mode=return_mode)
 
 
+def _do_bench_cudagraph_with_cache_clear(
+    fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"
+):
+    """Clone of triton.testing.do_bench_cudagraph with explicit L2 cache clearing."""
+    assert return_mode in ["min", "max", "mean", "median", "all"]
+
+    cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
+
+    with torch.cuda.stream(torch.cuda.Stream()):
+        cache.zero_()
+        fn()
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.detach_()
+                x.requires_grad_(True)
+                x.grad = None
+
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        for _ in range(5):
+            cache.zero_()
+            fn()
+        end_event.record()
+        torch.cuda.synchronize()
+        estimate_ms = start_event.elapsed_time(end_event) / 5
+
+        n_repeat = 1000 if estimate_ms == 0 else max(1, int(rep / estimate_ms))
+
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            for _ in range(n_repeat):
+                if grad_to_none is not None:
+                    for x in grad_to_none:
+                        x.grad = None
+                cache.zero_()
+                fn()
+        torch.cuda.synchronize()
+
+        ret = []
+        n_retries = 10
+        for _ in range(n_retries):
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            start_event.record()
+            g.replay()
+            end_event.record()
+            torch.cuda.synchronize()
+            ret.append(start_event.elapsed_time(end_event) / n_repeat)
+
+        times = torch.tensor(ret, dtype=torch.float)
+        return _summarize_statistics(times, quantiles, return_mode)
+
+
 def _do_bench_profiler(
     fn, warmup, rep, return_mode="all", grad_to_none=None, use_cudagraph=False
 ):
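
How the new helper works: it first times five cache-cleared calls to estimate per-call latency, sizes n_repeat so that one replay spans roughly rep milliseconds, captures n_repeat iterations into a single CUDA graph with cache.zero_() before each call, then replays the graph 10 times and records the per-iteration average of each replay. Since the L2 flush is captured inside the graph, its cost is included in every measured iteration, so the numbers are slightly conservative compared to timing fn() alone.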
@@ -383,7 +437,7 @@ def do_bench_wrapper(
         if latency_measure_mode == "profiler":
             bench_fn = partial(_do_bench_profiler, warmup=1, use_cudagraph=True)
         else:
-            bench_fn = triton.testing.do_bench_cudagraph
+            bench_fn = _do_bench_cudagraph_with_cache_clear
 
         return Latency(
             times=bench_fn(
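
A usage sketch, assuming a CUDA device and a tritonbench checkout; the helper is a private function in tritonbench.components.do_bench.run, so the import path below may change:

import torch
from tritonbench.components.do_bench.run import _do_bench_cudagraph_with_cache_clear

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")

# Mean latency in milliseconds; L2 is zeroed before every captured call, so
# each iteration sees a cold cache instead of data warmed by the previous one.
mean_ms = _do_bench_cudagraph_with_cache_clear(lambda: a @ b, rep=20, return_mode="mean")
print(f"4096x4096 fp32 matmul: {mean_ms:.3f} ms")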
