
Commit f6a4dda

Merge commit '2cc227d6c9604d44b059d8b0ab0c1fefb273a3bb'
2 parents: 428a24f + 2cc227d

File tree

1 file changed: +2, -8 lines


python/triton/testing.py

Lines changed: 2 additions & 8 deletions
@@ -139,8 +139,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
     return _summarize_statistics(torch.tensor(ret), quantiles, return_mode)
 
 
-def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
-             device_type="xpu"):
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device_type="xpu"):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
@@ -155,8 +154,6 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
     :type grad_to_none: torch.tensor, optional
     :param quantiles: Performance percentile to return in addition to the median.
     :type quantiles: list[float], optional
-    :param fast_flush: Use faster kernel to flush L2 cache between measurements
-    :type fast_flush: bool, default is True
     :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". :type return_mode: str
     """
     assert return_mode in ["min", "max", "mean", "median", "all"]
@@ -171,10 +168,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
     # before each kernel call to make sure that the L2 cache
     # doesn't contain any input data before the run
     cache_size = 256 * 1024 * 1024
-    if fast_flush:
-        cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device_type)
-    else:
-        cache = torch.empty(int(cache_size), dtype=torch.int8, device=device_type)
+    cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device_type)
 
     # Estimate the runtime of the function
     start_event = Event(enable_timing=True)
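The branch that survives is the old fast_flush=True path: cache_size // 4 int32 elements still cover the full 256 MB, written by the faster fill kernel that the deleted docstring referred to. A rough sketch of how such a flush buffer is used between measurements (an illustration of the technique, assuming a zeroing step like upstream Triton's, not the verbatim loop from testing.py):

import torch

device_type = "xpu"  # the hard-coded default in the new signature
cache_size = 256 * 1024 * 1024  # same 256 MB footprint as before

# Kept by this commit: int32 elements, i.e. a quarter of the element
# count of the removed int8 branch for the same number of bytes.
cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device_type)

def run_timed(fn):
    # Overwriting the whole buffer evicts fn's inputs from L2, so each
    # measurement starts from a cold cache.
    cache.zero_()
    return fn()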
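For callers, the change is purely subtractive: any call site still passing fast_flush now raises a TypeError, and the rest of the signature is unchanged. A minimal migration example, assuming an available XPU device and a stand-in matmul workload:

import torch
from triton.testing import do_bench

a = torch.randn(1024, 1024, device="xpu")
b = torch.randn(1024, 1024, device="xpu")

# Before this commit: do_bench(lambda: a @ b, fast_flush=True)
# After it, the keyword is gone; drop it and nothing else changes:
ms = do_bench(lambda: a @ b, warmup=25, rep=100, return_mode="mean")
print(f"mean runtime: {ms:.3f} ms")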
