From b1d2a0bea0b4674b5ad585402dfadb632a2a40f2 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 16 Sep 2024 11:16:29 +0000 Subject: [PATCH 01/17] Increase 'warmup' and 'rep' for FA benchmark Signed-off-by: Anatoly Myachev --- .../flash_attention_fwd_benchmark.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index 78a4a9de12..884f3d013b 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -217,10 +217,11 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider): v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype) sm_scale = 0.125 quantiles = [0.5, 0.0, 1.0] + warmup, rep = 100, 100 if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= - False, scale=sm_scale), warmup=10, rep=10, + False, scale=sm_scale), warmup=warmup, rep=rep, quantiles=quantiles, fast_flush=False) elif provider == 'triton': @@ -231,7 +232,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider): q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32) atol = 1e-1 if N_CTX == 16384 else 1e-2 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=warmup, rep=rep, quantiles=quantiles, fast_flush=False) elif provider == 'xetla': @@ -246,7 +247,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider): l = torch.empty((size_ml, ), device='xpu', dtype=torch.float) xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX) - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=warmup, rep=rep, quantiles=quantiles, fast_flush=False) else: From 5ebbd013df260361a5e6f33c2c40603fe1694601 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 16 Sep 2024 11:53:26 +0000 Subject: [PATCH 02/17] Use 150ms Signed-off-by: Anatoly Myachev --- .../triton_kernels_benchmark/flash_attention_fwd_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index 884f3d013b..f01ca40a87 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -217,7 +217,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider): v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype) sm_scale = 0.125 quantiles = [0.5, 0.0, 1.0] - warmup, rep = 100, 100 + warmup, rep = 150, 150 if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= From b1cc599a72d72cd65c758a1968cef702b5b78312 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Tue, 17 Sep 2024 12:45:54 +0200 Subject: [PATCH 03/17] Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py --- 
.../triton_kernels_benchmark/flash_attention_fwd_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index f01ca40a87..5141100afe 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -217,7 +217,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider): v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype) sm_scale = 0.125 quantiles = [0.5, 0.0, 1.0] - warmup, rep = 150, 150 + warmup, rep = 200, 200 if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= From 0ad146f2741022005dbf9673a580b40e2baed5be Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Tue, 17 Sep 2024 22:28:42 +0200 Subject: [PATCH 04/17] Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py --- .../triton_kernels_benchmark/flash_attention_fwd_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index 5141100afe..33a0e76cbd 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -217,7 +217,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider): v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype) sm_scale = 0.125 quantiles = [0.5, 0.0, 1.0] - warmup, rep = 200, 200 + warmup, rep = 300, 300 if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= From 81fec9a79a5826140dade00e8c663385c5ac6951 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 23 Sep 2024 15:08:52 +0200 Subject: [PATCH 05/17] Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py --- .../triton_kernels_benchmark/flash_attention_fwd_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index 33a0e76cbd..63e8446d83 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -217,7 +217,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider): v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype) sm_scale = 0.125 quantiles = [0.5, 0.0, 1.0] - warmup, rep = 300, 300 + warmup, rep = 10, 300 if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= From 8f81c13631390c02f6924b5adb18b534d9d80333 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 23 Sep 2024 17:08:43 +0200 Subject: [PATCH 06/17] Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py --- .../triton_kernels_benchmark/flash_attention_fwd_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py 
b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index e50f780792..1031dcbc2e 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -217,7 +217,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider): v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype) sm_scale = 0.125 quantiles = [0.5, 0.0, 1.0] - warmup, rep = 10, 300 + warmup, rep = 10, 200 if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= From b2d3398bada9a6d083004def893637f29da1f3fb Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Sun, 29 Sep 2024 19:55:36 +0200 Subject: [PATCH 07/17] Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py --- .../triton_kernels_benchmark/flash_attention_fwd_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index 4a963397e7..2c697bb0c5 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -234,7 +234,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype) sm_scale = 0.125 quantiles = [0.5, 0.0, 1.0] - warmup, rep = 10, 200 + warmup, rep = 10, 300 if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= From fe806b1dbbfe82a7ffcae848c0a56ab046149158 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 30 Sep 2024 10:52:02 +0200 Subject: [PATCH 08/17] Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py --- .../triton_kernels_benchmark/flash_attention_fwd_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index 2c697bb0c5..fd0aeb47a7 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -234,7 +234,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype) sm_scale = 0.125 quantiles = [0.5, 0.0, 1.0] - warmup, rep = 10, 300 + warmup, rep = 10, 400 if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= From bf49b0d5fe590973c2ff38fcab76ea4a589816e8 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 30 Sep 2024 09:25:35 +0000 Subject: [PATCH 09/17] fix after merge Signed-off-by: Anatoly Myachev --- .../flash_attention_fwd_benchmark.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index fd0aeb47a7..5735690a24 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -234,7 +234,7 @@ def 
benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype) sm_scale = 0.125 quantiles = [0.5, 0.0, 1.0] - warmup, rep = 10, 400 + warmup, rep = 10, 200 if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= @@ -258,7 +258,8 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32) atol = 1e-1 if N_CTX == 16384 else 1e-2 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles) + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=warmup, rep=rep, + quantiles=quantiles) elif provider == 'xetla': module_name = f'flash_attn_causal_{CAUSAL}'.lower() From 74936327c54d6cac5bbcd9ae34a4f8739a3453d5 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 30 Sep 2024 12:16:01 +0200 Subject: [PATCH 10/17] Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py --- .../triton_kernels_benchmark/flash_attention_fwd_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index 5735690a24..b62a1f2d72 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -234,7 +234,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype) sm_scale = 0.125 quantiles = [0.5, 0.0, 1.0] - warmup, rep = 10, 200 + warmup, rep = 10, 300 if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= From 524f81d6da59b4826588011f971918a91bda3ea7 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 30 Sep 2024 13:10:46 +0200 Subject: [PATCH 11/17] Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py --- .../triton_kernels_benchmark/flash_attention_fwd_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index b62a1f2d72..d216402968 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -234,7 +234,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype) sm_scale = 0.125 quantiles = [0.5, 0.0, 1.0] - warmup, rep = 10, 300 + warmup, rep = 10, 400 if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= From 4d408644cdb2c0953f23e886b7f47b7901cc7c4c Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 30 Sep 2024 14:02:51 +0200 Subject: [PATCH 12/17] Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py --- .../triton_kernels_benchmark/flash_attention_fwd_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index d216402968..61bc786c3a 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -234,7 +234,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype) sm_scale = 0.125 quantiles = [0.5, 0.0, 1.0] - warmup, rep = 10, 400 + warmup, rep = 10, 500 if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= From b0d91ce5d5901d696e69e5618602fe7bc76adfa5 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 30 Sep 2024 15:03:42 +0200 Subject: [PATCH 13/17] Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py --- .../triton_kernels_benchmark/flash_attention_fwd_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index 61bc786c3a..e2c6fa7e08 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -234,7 +234,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype) sm_scale = 0.125 quantiles = [0.5, 0.0, 1.0] - warmup, rep = 10, 500 + warmup, rep = 10, 600 if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= From e1c4f9fdfdc432330c3854b11e610d365558e558 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 14 Oct 2024 16:37:01 +0000 Subject: [PATCH 14/17] Change do_bench* signatures Signed-off-by: Anatoly Myachev --- .../benchmark_testing.py | 89 +++++++++---------- 1 file changed, 43 insertions(+), 46 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py index b289a096e2..37335677b6 100644 --- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py +++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py @@ -36,7 +36,7 @@ def _summarize_statistics(times, quantiles, return_mode): return getattr(torch, return_mode)(times).item() -def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu", +def do_bench_ipex(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu", sync_submitting=True, kernel_name=None): # pylint: disable=unused-argument """ Benchmark the runtime of the provided function. 
By default, return the median runtime of :code:`fn` along with @@ -44,10 +44,10 @@ def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, ret :param fn: Function to benchmark :type fn: Callable - :param warmup: Warmup time (in ms) - :type warmup: int - :param rep: Repetition time (in ms) - :type rep: int + :param n_warmup: Number of repetitions for warmup + :type n_warmup: int + :param n_repeat: Number of repetitions to collect measurements + :type n_repeat: int :param grad_to_none: Reset the gradient of the provided tensor to None :type grad_to_none: torch.tensor, optional :param quantiles: Performance percentile to return in addition to the median. @@ -69,20 +69,6 @@ def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, ret cache_size = 256 * 1024 * 1024 cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device) - # Estimate the runtime of the function - start_event = torch.xpu.Event(enable_timing=True) - end_event = torch.xpu.Event(enable_timing=True) - start_event.record() - for _ in range(5): - cache.zero_() - fn() - end_event.record() - synchronize() - estimate_ms = start_event.elapsed_time(end_event) / 5 - - # compute number of warmup and repeat - n_warmup = max(1, int(warmup / estimate_ms)) - n_repeat = max(1, int(rep / estimate_ms)) # Warm-up for _ in range(n_warmup): fn() @@ -121,18 +107,18 @@ def extract_kernels(funcs): return _summarize_statistics(times, quantiles, return_mode) -def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu", - kernel_name=None): # pylint: disable=unused-argument +def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean", + device="xpu", kernel_name=None): # pylint: disable=unused-argument """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. :param fn: Function to benchmark :type fn: Callable - :param warmup: Warmup time (in ms) - :type warmup: int - :param rep: Repetition time (in ms) - :type rep: int + :param n_warmup: Number of repetitions for warmup + :type n_warmup: int + :param n_repeat: Number of repetitions to collect measurements + :type n_repeat: int :param grad_to_none: Reset the gradient of the provided tensor to None :type grad_to_none: torch.tensor, optional :param quantiles: Performance percentile to return in addition to the median. @@ -142,24 +128,49 @@ def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=N import torch from triton.testing import do_bench as triton_do_bench - times = triton_do_bench(fn, warmup=warmup, rep=rep, grad_to_none=grad_to_none, return_mode="all", + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 + # doesn't contain any input data before the run + cache_size = 256 * 1024 * 1024 + cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device) + + # Estimate the runtime of the function + start_event = torch.xpu.Event(enable_timing=True) + end_event = torch.xpu.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # The cache is also maintained in `triton_do_bench` function, + # there is no need to duplicate the amount of memory used. 
+ del cache + + # compute warmup and repeat times + warmup_time = n_warmup * estimate_ms + rep_time = n_repeat * estimate_ms + + times = triton_do_bench(fn, warmup=warmup_time, rep=rep_time, grad_to_none=grad_to_none, return_mode="all", device_type=device) times = torch.tensor(times, dtype=torch.float) return _summarize_statistics(times, quantiles, return_mode) -def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", - device="xpu", sync_submitting=True, kernel_name=None): +def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, + return_mode="mean", device="xpu", sync_submitting=True, kernel_name=None): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. :param fn: Function to benchmark :type fn: Callable - :param warmup: Warmup time (in ms) - :type warmup: int - :param rep: Repetition time (in ms) - :type rep: int + :param n_warmup: Number of repetitions for warmup + :type n_warmup: int + :param n_repeat: Number of repetitions to collect measurements + :type n_repeat: int :param grad_to_none: Reset the gradient of the provided tensor to None :type grad_to_none: torch.tensor, optional :param quantiles: Performance percentile to return in addition to the median. @@ -179,20 +190,6 @@ def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None cache_size = 256 * 1024 * 1024 cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device) - # Estimate the runtime of the function - start_event = torch.xpu.Event(enable_timing=True) - end_event = torch.xpu.Event(enable_timing=True) - start_event.record() - for _ in range(5): - cache.zero_() - fn() - end_event.record() - synchronize() - estimate_ms = start_event.elapsed_time(end_event) / 5 - - # compute number of warmup and repeat - n_warmup = max(1, int(warmup / estimate_ms)) - n_repeat = max(1, int(rep / estimate_ms)) # Warm-up for _ in range(n_warmup): fn() From a1fd0f97ff452e5cf825b688f5e53b8f4afa0ef2 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 14 Oct 2024 16:40:30 +0000 Subject: [PATCH 15/17] cleanup Signed-off-by: Anatoly Myachev --- .../flash_attention_fwd_benchmark.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index 0f60e607da..2f7716a93c 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -239,11 +239,10 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype) sm_scale = 0.125 quantiles = [0.5, 0.0, 1.0] - warmup, rep = 10, 600 if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= - CAUSAL, scale=sm_scale), warmup=warmup, rep=rep, + CAUSAL, scale=sm_scale), warmup=10, rep=10, quantiles=quantiles) elif provider == 'triton': @@ -257,7 +256,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32) atol = 1e-1 if N_CTX == 16384 else 1e-2 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to 
torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=warmup, rep=rep, quantiles=quantiles, + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, kernel_name='_attn_fwd') elif provider == 'xetla': @@ -273,7 +272,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): l = torch.empty((size_ml, ), device='xpu', dtype=torch.float) xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale) - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=warmup, rep=rep, quantiles=quantiles, + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, kernel_name='gpu::xetla::fmha::FmhaForwardKernel<') else: From f16b1496c90855182192d4771dbc81b3b4d5bc67 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 14 Oct 2024 17:01:14 +0000 Subject: [PATCH 16/17] fixes Signed-off-by: Anatoly Myachev --- .../flash_attention_fwd_benchmark.py | 6 +++--- benchmarks/triton_kernels_benchmark/fused_softmax.py | 10 +++++----- benchmarks/triton_kernels_benchmark/gemm_benchmark.py | 6 +++--- .../gemm_postop_addmatrix_benchmark.py | 2 +- .../gemm_postop_gelu_benchmark.py | 2 +- .../gemm_preop_exp_benchmark.py | 2 +- .../triton_kernels_benchmark/gemm_splitk_benchmark.py | 4 ++-- .../triton_kernels_benchmark/gemm_streamk_benchmark.py | 4 ++-- 8 files changed, 18 insertions(+), 18 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index 2f7716a93c..f1fc1db878 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -242,7 +242,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= - CAUSAL, scale=sm_scale), warmup=10, rep=10, + CAUSAL, scale=sm_scale), m_warmup=10, n_rep=10, quantiles=quantiles) elif provider == 'triton': @@ -256,7 +256,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32) atol = 1e-1 if N_CTX == 16384 else 1e-2 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles, kernel_name='_attn_fwd') elif provider == 'xetla': @@ -272,7 +272,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): l = torch.empty((size_ml, ), device='xpu', dtype=torch.float) xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale) - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_rep=10, quantiles=quantiles, kernel_name='gpu::xetla::fmha::FmhaForwardKernel<') else: diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py index dfd080ac34..acfa0e1eea 100644 --- a/benchmarks/triton_kernels_benchmark/fused_softmax.py +++ 
b/benchmarks/triton_kernels_benchmark/fused_softmax.py @@ -125,18 +125,18 @@ def benchmark(M, N, provider): quantiles = [0.5, 0.0, 1.0] if provider == "torch-native": _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, - warmup=10, rep=10) + n_warmup=10, n_rep=10) if provider == "triton": out = torch.empty_like(x, device="xpu") triton_fn = lambda: softmax(x, out) torch_fn = lambda: torch.softmax(x, axis=-1) benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch") - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10, + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_rep=10, kernel_name="softmax_kernel") elif provider == "torch-jit": - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10, - rep=10) + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, + n_warmup=10, n_rep=10) elif provider == "xetla": name = f"softmax_shape_{M}_{N}" @@ -154,7 +154,7 @@ def benchmark(M, N, provider): "softmax_shape_4096_16384": "mat1_4096x16k_bf16_cfg0", "softmax_shape_4096_32768": "mat1_4096x32k_bf16_cfg0", } - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10, + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_rep=10, kernel_name=kernels_name[name]) else: diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py index f54ef2abdd..893f12cbae 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py @@ -283,7 +283,7 @@ def benchmark(B, M, N, K, provider): if BENCHMARKING_METHOD == 'PYTORCH_LEGACY_PROFILER_USING_IPEX': # Legacy profiler shows ~6000TFLOPS GeoMean for onednn measurements, so use more reliable method do_bench = do_bench_elapsed_time - _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), warmup=10, rep=10, + _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), n_warmup=10, n_rep=10, quantiles=quantiles, kernel_name='gemm_kernel') elif provider == 'triton': assert len(a.shape) == len(b.shape), 'Incompatible sizes' @@ -296,7 +296,7 @@ def benchmark(B, M, N, K, provider): torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles, kernel_name='matmul_kernel_with_block_pointers') elif provider == 'xetla': if B == 1: @@ -340,7 +340,7 @@ def benchmark(B, M, N, K, provider): } # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_rep=10, quantiles=quantiles, kernel_name=kernels_name[name]) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git 
a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py index 8be5f6ce01..291998429b 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py @@ -275,7 +275,7 @@ def benchmark(B, M, N, K, provider): torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles, kernel_name=kernel_name) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py index f504aa9952..305a464bc9 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py @@ -277,7 +277,7 @@ def benchmark(B, M, N, K, provider): torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32)) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles, kernel_name=kernel_name) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py index 3e9836e225..a69429aacc 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py @@ -265,7 +265,7 @@ def benchmark(B, M, N, K, provider): torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles, kernel_name=kernel_name) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py index 6af22c2603..7b87fe384c 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py @@ -148,7 +148,7 @@ def benchmark(M, N, K, provider): quantiles = [0.5, 0.0, 1.0] if provider == 'onednn': - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_rep=10, quantiles=quantiles) elif provider == 'triton': c = torch.empty((M, N), device='xpu', dtype=torch.float32) @@ -156,7 +156,7 @@ def benchmark(M, N, K, 
provider): torch_fn = lambda: torch.matmul(a, b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles, kernel_name='_kernel') else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py index e8179cd45a..ae41b8a385 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py @@ -271,14 +271,14 @@ def benchmark(M, N, K, provider): quantiles = [0.5, 0.0, 1.0] if provider == 'onednn': - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_rep=10, quantiles=quantiles) elif provider == 'triton': c = torch.empty((M, N), device=a.device, dtype=torch.float32) triton_fn = lambda: matmul(a, b, c) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles, kernel_name=['first_wave', 'full_tiles']) else: raise NotImplementedError(f'Unsupported provider {provider}') From 565d87cdb998339a14ff2bf88faee8c6c5fdf602 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 14 Oct 2024 17:19:08 +0000 Subject: [PATCH 17/17] fix Signed-off-by: Anatoly Myachev --- .../flash_attention_fwd_benchmark.py | 6 +++--- benchmarks/triton_kernels_benchmark/fused_softmax.py | 8 ++++---- benchmarks/triton_kernels_benchmark/gemm_benchmark.py | 9 +++++---- .../gemm_postop_addmatrix_benchmark.py | 4 ++-- .../gemm_postop_gelu_benchmark.py | 4 ++-- .../triton_kernels_benchmark/gemm_preop_exp_benchmark.py | 4 ++-- .../triton_kernels_benchmark/gemm_splitk_benchmark.py | 4 ++-- .../triton_kernels_benchmark/gemm_streamk_benchmark.py | 5 +++-- 8 files changed, 23 insertions(+), 21 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index f1fc1db878..dc073d0e5c 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -242,7 +242,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= - CAUSAL, scale=sm_scale), m_warmup=10, n_rep=10, + CAUSAL, scale=sm_scale), m_warmup=10, n_repeat=10, quantiles=quantiles) elif provider == 'triton': @@ -256,7 +256,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32) atol = 1e-1 if N_CTX == 16384 else 1e-2 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, 
err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles, + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles, kernel_name='_attn_fwd') elif provider == 'xetla': @@ -272,7 +272,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): l = torch.empty((size_ml, ), device='xpu', dtype=torch.float) xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale) - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_rep=10, quantiles=quantiles, + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles, kernel_name='gpu::xetla::fmha::FmhaForwardKernel<') else: diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py index acfa0e1eea..3f17ac4a55 100644 --- a/benchmarks/triton_kernels_benchmark/fused_softmax.py +++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py @@ -125,18 +125,18 @@ def benchmark(M, N, provider): quantiles = [0.5, 0.0, 1.0] if provider == "torch-native": _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, - n_warmup=10, n_rep=10) + n_warmup=10, n_repeat=10) if provider == "triton": out = torch.empty_like(x, device="xpu") triton_fn = lambda: softmax(x, out) torch_fn = lambda: torch.softmax(x, axis=-1) benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch") - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_rep=10, + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_repeat=10, kernel_name="softmax_kernel") elif provider == "torch-jit": _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, - n_warmup=10, n_rep=10) + n_warmup=10, n_repeat=10) elif provider == "xetla": name = f"softmax_shape_{M}_{N}" @@ -154,7 +154,7 @@ def benchmark(M, N, provider): "softmax_shape_4096_16384": "mat1_4096x16k_bf16_cfg0", "softmax_shape_4096_32768": "mat1_4096x32k_bf16_cfg0", } - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_rep=10, + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_repeat=10, kernel_name=kernels_name[name]) else: diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py index 893f12cbae..b70465ee71 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py @@ -283,7 +283,7 @@ def benchmark(B, M, N, K, provider): if BENCHMARKING_METHOD == 'PYTORCH_LEGACY_PROFILER_USING_IPEX': # Legacy profiler shows ~6000TFLOPS GeoMean for onednn measurements, so use more reliable method do_bench = do_bench_elapsed_time - _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), n_warmup=10, n_rep=10, + _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), n_warmup=10, n_repeat=10, quantiles=quantiles, kernel_name='gemm_kernel') elif provider == 'triton': assert len(a.shape) == len(b.shape), 'Incompatible sizes' @@ -296,7 +296,8 @@ def benchmark(B, M, N, K, provider): torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32) rtol = 1e-2 if a.dtype == 
torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles, + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, + quantiles=quantiles, kernel_name='matmul_kernel_with_block_pointers') elif provider == 'xetla': if B == 1: @@ -340,8 +341,8 @@ def benchmark(B, M, N, K, provider): } # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_rep=10, quantiles=quantiles, - kernel_name=kernels_name[name]) + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, + quantiles=quantiles, kernel_name=kernels_name[name]) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py index 291998429b..307100dcfe 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py @@ -275,8 +275,8 @@ def benchmark(B, M, N, K, provider): torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles, - kernel_name=kernel_name) + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, + quantiles=quantiles, kernel_name=kernel_name) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py index 305a464bc9..85bb594ade 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py @@ -277,8 +277,8 @@ def benchmark(B, M, N, K, provider): torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32)) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles, - kernel_name=kernel_name) + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, + quantiles=quantiles, kernel_name=kernel_name) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py index a69429aacc..30ed124d44 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py @@ -265,8 +265,8 @@ def benchmark(B, M, N, K, provider): torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = 
benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles, - kernel_name=kernel_name) + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, + quantiles=quantiles, kernel_name=kernel_name) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py index 7b87fe384c..4aa1910591 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py @@ -148,7 +148,7 @@ def benchmark(M, N, K, provider): quantiles = [0.5, 0.0, 1.0] if provider == 'onednn': - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_rep=10, + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10, quantiles=quantiles) elif provider == 'triton': c = torch.empty((M, N), device='xpu', dtype=torch.float32) @@ -156,7 +156,7 @@ def benchmark(M, N, K, provider): torch_fn = lambda: torch.matmul(a, b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles, + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles, kernel_name='_kernel') else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py index ae41b8a385..94644c61f2 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py @@ -271,14 +271,15 @@ def benchmark(M, N, K, provider): quantiles = [0.5, 0.0, 1.0] if provider == 'onednn': - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_rep=10, + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10, quantiles=quantiles) elif provider == 'triton': c = torch.empty((M, N), device=a.device, dtype=torch.float32) triton_fn = lambda: matmul(a, b, c) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_rep=10, quantiles=quantiles, + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, + quantiles=quantiles, kernel_name=['first_wave', 'full_tiles']) else: raise NotImplementedError(f'Unsupported provider {provider}')
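Taken together, the series replaces the time-based 'warmup'/'rep' arguments (target durations in milliseconds) of the do_bench* helpers with explicit iteration counts 'n_warmup'/'n_repeat' and updates every benchmark call site accordingly. A minimal caller-side sketch of the change, assuming the benchmark_suit import alias and the xpu/bfloat16 setup used by these benchmark files (illustrative only, not part of the patches):

import torch
import triton_kernels_benchmark as benchmark_suit  # import alias assumed from the benchmark files

a = torch.randn((4096, 4096), device='xpu', dtype=torch.bfloat16)
b = torch.randn((4096, 4096), device='xpu', dtype=torch.bfloat16)
quantiles = [0.5, 0.0, 1.0]

# Before this series: 'warmup' and 'rep' were target durations in milliseconds;
# do_bench first estimated the kernel runtime and derived iteration counts from it.
# _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
#     lambda: torch.matmul(a, b), warmup=10, rep=10, quantiles=quantiles)

# After this series: 'n_warmup' and 'n_repeat' are explicit iteration counts, so the
# profiler-based benchmarks run a fixed number of measured iterations without a
# runtime-estimation pass.
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
    lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10, quantiles=quantiles)

Only do_bench_elapsed_time still delegates to Triton's time-based do_bench: it converts the iteration counts back into millisecond budgets through a short runtime-estimation loop, as shown in PATCH 14.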