
Commit 6ad95ee

[AUTOTUNER] A quick follow-up for more device-independent do_bench (#4974)

This is a quick follow-up to the recent autotuner/testing changes in triton-lang/triton#4496. This PR moves the empty-cache creation into the driver code to make it more device independent.

1 parent a1aa58b, commit 6ad95ee
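
To make the pattern concrete, here is a minimal before/after sketch of the change (abridged; the names follow the diffs below, and runtime is the module already in scope in triton/testing.py):

    # Before: triton/testing.py allocated the L2-flush buffer itself,
    # hard-coding both the buffer size and the CUDA device.
    #   cache = torch.empty(256 * 1024 * 1024 // 4, dtype=torch.int, device='cuda')

    # After: the active backend driver supplies the buffer, so do_bench
    # no longer needs to know which device it is running on.
    cache = runtime.driver.active.get_empty_cache_for_benchmark()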

File tree: 3 files changed, +18 -6 lines

python/triton/testing.py

Lines changed: 2 additions & 6 deletions

@@ -92,7 +92,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
     return _summarize_statistics(torch.tensor(ret), quantiles, return_mode)
 
 
-def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device_type="cuda"):
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
@@ -117,11 +117,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
     fn()
     di.synchronize()
 
-    # We maintain a buffer of 256 MB that we clear
-    # before each kernel call to make sure that the L2 cache
-    # doesn't contain any input data before the run
-    cache_size = 256 * 1024 * 1024
-    cache = torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
+    cache = runtime.driver.active.get_empty_cache_for_benchmark()
 
     # Estimate the runtime of the function
     start_event = di.Event(enable_timing=True)
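
Call sites are unchanged except that the removed device_type keyword can no longer be passed. A typical invocation after this commit (illustrative workload; assumes a CUDA- or ROCm-enabled PyTorch build):

    import torch
    from triton.testing import do_bench

    x = torch.randn(4096, 4096, device='cuda')
    y = torch.randn(4096, 4096, device='cuda')

    # do_bench now obtains the L2-flush buffer from the active driver;
    # no device_type argument is needed (or accepted) anymore.
    ms = do_bench(lambda: torch.mm(x, y), warmup=25, rep=100)
    print(f"matmul: {ms:.4f} ms")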

third_party/amd/backend/driver.py

Lines changed: 7 additions & 0 deletions

@@ -503,3 +503,10 @@ def get_current_target(self):
     def get_benchmarker(self):
         from triton.testing import do_bench
         return do_bench
+
+    def get_empty_cache_for_benchmark(self):
+        import torch
+
+        # It's the same as the Nvidia backend.
+        cache_size = 256 * 1024 * 1024
+        return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
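
The device='cuda' string above is deliberate: ROCm builds of PyTorch expose HIP devices under the cuda device type, so the AMD hook can reuse the NVIDIA allocation verbatim (hence the comment). The hook is also where a backend with a different memory hierarchy would diverge; a hypothetical sketch, not part of this commit:

    # Hypothetical hook for some other backend (illustrative only; not
    # in this commit). Such a backend could size the flush buffer to its
    # own last-level cache instead of the 256 MB GPU default.
    def get_empty_cache_for_benchmark(self):
        import torch

        cache_size = 32 * 1024 * 1024  # e.g. a 32 MB last-level cache
        return torch.empty(cache_size // 4, dtype=torch.int, device='cpu')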

third_party/nvidia/backend/driver.py

Lines changed: 9 additions & 0 deletions

@@ -452,3 +452,12 @@ def is_active():
     def get_benchmarker(self):
         from triton.testing import do_bench
         return do_bench
+
+    def get_empty_cache_for_benchmark(self):
+        import torch
+
+        # We maintain a buffer of 256 MB that we clear
+        # before each kernel call to make sure that the L2 cache
+        # doesn't contain any input data before the run
+        cache_size = 256 * 1024 * 1024
+        return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
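
On the sizing: 256 MB of 4-byte torch.int elements gives a tensor of 67,108,864 elements, far larger than any current GPU L2 cache, so clearing it evicts the benchmark's inputs. An abridged sketch of how the timing loop in do_bench uses the buffer (gradient handling and statistics omitted; di is the device interface, e.g. torch.cuda):

    # Abridged sketch of the do_bench timing loop, not the full code.
    def timed_reps(fn, cache, n_repeat, di):
        start = [di.Event(enable_timing=True) for _ in range(n_repeat)]
        end = [di.Event(enable_timing=True) for _ in range(n_repeat)]
        for i in range(n_repeat):
            cache.zero_()    # writing 256 MB evicts inputs from the L2 cache
            start[i].record()
            fn()             # the kernel under test starts from a cold cache
            end[i].record()
        di.synchronize()
        return [s.elapsed_time(e) for s, e in zip(start, end)]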
