diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py index 9873a91668..eab0b7ac89 100644 --- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py +++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py @@ -308,7 +308,9 @@ def extract_kernels(funcs): raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented") -def get_do_bench(n_warmup: int, n_repeat: int, quantiles: list): +def get_do_bench(n_warmup: int, n_repeat: int, quantiles: list, clear_cache: bool = True): + if clear_cache: + torch.xpu.empty_cache() return functools.partial(do_bench, n_warmup=n_warmup, n_repeat=n_repeat, quantiles=quantiles)