diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
index b289a096e2..37335677b6 100644
--- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py
+++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -36,7 +36,7 @@ def _summarize_statistics(times, quantiles, return_mode):
     return getattr(torch, return_mode)(times).item()


-def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu",
+def do_bench_ipex(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu",
                   sync_submitting=True, kernel_name=None):  # pylint: disable=unused-argument
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
@@ -44,10 +44,10 @@ def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, ret

     :param fn: Function to benchmark
     :type fn: Callable
-    :param warmup: Warmup time (in ms)
-    :type warmup: int
-    :param rep: Repetition time (in ms)
-    :type rep: int
+    :param n_warmup: Number of repetitions for warmup
+    :type n_warmup: int
+    :param n_repeat: Number of repetitions to collect measurements
+    :type n_repeat: int
     :param grad_to_none: Reset the gradient of the provided tensor to None
     :type grad_to_none: torch.tensor, optional
     :param quantiles: Performance percentile to return in addition to the median.
@@ -69,20 +69,6 @@ def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, ret
     cache_size = 256 * 1024 * 1024
     cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)

-    # Estimate the runtime of the function
-    start_event = torch.xpu.Event(enable_timing=True)
-    end_event = torch.xpu.Event(enable_timing=True)
-    start_event.record()
-    for _ in range(5):
-        cache.zero_()
-        fn()
-    end_event.record()
-    synchronize()
-    estimate_ms = start_event.elapsed_time(end_event) / 5
-
-    # compute number of warmup and repeat
-    n_warmup = max(1, int(warmup / estimate_ms))
-    n_repeat = max(1, int(rep / estimate_ms))
     # Warm-up
     for _ in range(n_warmup):
         fn()
@@ -121,18 +107,18 @@ def extract_kernels(funcs):
     return _summarize_statistics(times, quantiles, return_mode)


-def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu",
-                          kernel_name=None):  # pylint: disable=unused-argument
+def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean",
+                          device="xpu", kernel_name=None):  # pylint: disable=unused-argument
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.

     :param fn: Function to benchmark
     :type fn: Callable
-    :param warmup: Warmup time (in ms)
-    :type warmup: int
-    :param rep: Repetition time (in ms)
-    :type rep: int
+    :param n_warmup: Number of repetitions for warmup
+    :type n_warmup: int
+    :param n_repeat: Number of repetitions to collect measurements
+    :type n_repeat: int
     :param grad_to_none: Reset the gradient of the provided tensor to None
     :type grad_to_none: torch.tensor, optional
     :param quantiles: Performance percentile to return in addition to the median.
@@ -142,24 +128,49 @@ def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=N
     import torch
     from triton.testing import do_bench as triton_do_bench

-    times = triton_do_bench(fn, warmup=warmup, rep=rep, grad_to_none=grad_to_none, return_mode="all",
+    # We maintain a buffer of 256 MB that we clear
+    # before each kernel call to make sure that the L2
+    # doesn't contain any input data before the run
+    cache_size = 256 * 1024 * 1024
+    cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
+
+    # Estimate the runtime of the function
+    start_event = torch.xpu.Event(enable_timing=True)
+    end_event = torch.xpu.Event(enable_timing=True)
+    start_event.record()
+    for _ in range(5):
+        cache.zero_()
+        fn()
+    end_event.record()
+    synchronize()
+    estimate_ms = start_event.elapsed_time(end_event) / 5
+
+    # The cache is also maintained inside `triton_do_bench`,
+    # so there is no need to duplicate the memory usage here.
+    del cache
+
+    # compute warmup and repeat times
+    warmup_time = n_warmup * estimate_ms
+    rep_time = n_repeat * estimate_ms
+
+    times = triton_do_bench(fn, warmup=warmup_time, rep=rep_time, grad_to_none=grad_to_none, return_mode="all",
                             device_type=device)
     times = torch.tensor(times, dtype=torch.float)
     return _summarize_statistics(times, quantiles, return_mode)


-def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean",
-                                       device="xpu", sync_submitting=True, kernel_name=None):
+def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None,
+                                       return_mode="mean", device="xpu", sync_submitting=True, kernel_name=None):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.

     :param fn: Function to benchmark
     :type fn: Callable
-    :param warmup: Warmup time (in ms)
-    :type warmup: int
-    :param rep: Repetition time (in ms)
-    :type rep: int
+    :param n_warmup: Number of repetitions for warmup
+    :type n_warmup: int
+    :param n_repeat: Number of repetitions to collect measurements
+    :type n_repeat: int
     :param grad_to_none: Reset the gradient of the provided tensor to None
     :type grad_to_none: torch.tensor, optional
     :param quantiles: Performance percentile to return in addition to the median.
@@ -179,20 +190,6 @@ def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None
     cache_size = 256 * 1024 * 1024
     cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)

-    # Estimate the runtime of the function
-    start_event = torch.xpu.Event(enable_timing=True)
-    end_event = torch.xpu.Event(enable_timing=True)
-    start_event.record()
-    for _ in range(5):
-        cache.zero_()
-        fn()
-    end_event.record()
-    synchronize()
-    estimate_ms = start_event.elapsed_time(end_event) / 5
-
-    # compute number of warmup and repeat
-    n_warmup = max(1, int(warmup / estimate_ms))
-    n_repeat = max(1, int(rep / estimate_ms))
     # Warm-up
     for _ in range(n_warmup):
         fn()
diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
index 2f7716a93c..dc073d0e5c 100644
--- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
@@ -242,7 +242,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
     if provider == 'onednn':
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
             lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=
-                                                                     CAUSAL, scale=sm_scale), warmup=10, rep=10,
+                                                                     CAUSAL, scale=sm_scale), n_warmup=10, n_repeat=10,
             quantiles=quantiles)

     elif provider == 'triton':
@@ -256,7 +256,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
         ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
         atol = 1e-1 if N_CTX == 16384 else 1e-2
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
                                                               kernel_name='_attn_fwd')

     elif provider == 'xetla':
@@ -272,7 +272,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
         l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)

         xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
                                                               kernel_name='gpu::xetla::fmha::FmhaForwardKernel<')

     else:
diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py
index dfd080ac34..3f17ac4a55 100644
--- a/benchmarks/triton_kernels_benchmark/fused_softmax.py
+++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -125,18 +125,18 @@ def benchmark(M, N, provider):
     quantiles = [0.5, 0.0, 1.0]
     if provider == "torch-native":
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles,
-                                                              warmup=10, rep=10)
+                                                              n_warmup=10, n_repeat=10)
     if provider == "triton":
         out = torch.empty_like(x, device="xpu")
         triton_fn = lambda: softmax(x, out)
         torch_fn = lambda: torch.softmax(x, axis=-1)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_repeat=10,
                                                               kernel_name="softmax_kernel")

     elif provider == "torch-jit":
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10,
-                                                              rep=10)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles,
+                                                              n_warmup=10, n_repeat=10)

     elif provider == "xetla":
         name = f"softmax_shape_{M}_{N}"
@@ -154,7 +154,7 @@ def benchmark(M, N, provider):
             "softmax_shape_4096_16384": "mat1_4096x16k_bf16_cfg0",
             "softmax_shape_4096_32768": "mat1_4096x32k_bf16_cfg0",
         }
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_repeat=10,
                                                               kernel_name=kernels_name[name])

     else:
diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
index f54ef2abdd..b70465ee71 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -283,7 +283,7 @@ def benchmark(B, M, N, K, provider):
         if BENCHMARKING_METHOD == 'PYTORCH_LEGACY_PROFILER_USING_IPEX':
            # Legacy profiler shows ~6000TFLOPS GeoMean for onednn measurements, so use more reliable method
            do_bench = do_bench_elapsed_time
-        _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), warmup=10, rep=10,
+        _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), n_warmup=10, n_repeat=10,
                                                   quantiles=quantiles, kernel_name='gemm_kernel')
     elif provider == 'triton':
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
@@ -296,7 +296,8 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles,
                                                                  kernel_name='matmul_kernel_with_block_pointers')
     elif provider == 'xetla':
         if B == 1:
@@ -340,8 +341,8 @@ def benchmark(B, M, N, K, provider):
         }

         # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 kernel_name=kernels_name[name])
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles, kernel_name=kernels_name[name])

     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
index 8be5f6ce01..307100dcfe 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
@@ -275,8 +275,8 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 kernel_name=kernel_name)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles, kernel_name=kernel_name)

     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
index f504aa9952..85bb594ade 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
@@ -277,8 +277,8 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 kernel_name=kernel_name)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles, kernel_name=kernel_name)

     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
index 3e9836e225..30ed124d44 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
@@ -265,8 +265,8 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 kernel_name=kernel_name)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles, kernel_name=kernel_name)

     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
index 6af22c2603..4aa1910591 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -148,7 +148,7 @@ def benchmark(M, N, K, provider):
     quantiles = [0.5, 0.0, 1.0]

     if provider == 'onednn':
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
                                                               quantiles=quantiles)
     elif provider == 'triton':
         c = torch.empty((M, N), device='xpu', dtype=torch.float32)
@@ -156,7 +156,7 @@ def benchmark(M, N, K, provider):
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
                                                               kernel_name='_kernel')
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
index e8179cd45a..94644c61f2 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -271,14 +271,15 @@ def benchmark(M, N, K, provider):
     quantiles = [0.5, 0.0, 1.0]

     if provider == 'onednn':
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     elif provider == 'triton':
         c = torch.empty((M, N), device=a.device, dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles,
                                                                  kernel_name=['first_wave', 'full_tiles'])
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')