# This benchmark requires a PyTorch version with FlexAttention support for XPU available.
import os
from functools import lru_cache

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (
    create_block_mask,
    flex_attention,
)

import triton_kernels_benchmark as benchmark_suit
from triton_kernels_benchmark import xetla_kernel

# Compile flex_attention: the eager entry point relies on torch.compile to generate the
# fused attention kernels; dynamic=False specializes the compiled graph to static shapes.
flex_attention = torch.compile(flex_attention, dynamic=False)
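# Optional sanity check (not part of the original benchmark): on recent PyTorch builds,
# torch.xpu.is_available() can be queried to confirm an XPU device is visible before running
# this script.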


# lru_cache memoizes the BlockMask for a given (mask_mod, B, H, M, N, device) combination,
# so the mask is built only once per unique shape instead of on every benchmark call.
@lru_cache
def create_block_mask_cached(mask_mod, B, H, M, N, device='xpu'):
    return create_block_mask(mask_mod, B, H, M, N, device=device)


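# Mask modifier used by create_block_mask/flex_attention: it is evaluated per
# (batch, head, q_idx, kv_idx) position, and a True return value keeps that position.
# q_idx >= kv_idx keeps the lower triangle, i.e. standard causal attention. For example,
# with a sequence length of 4 the kept positions are:
#
#     kv_idx:   0      1      2      3
#   q_idx 0 [ True, False, False, False]
#   q_idx 1 [ True,  True, False, False]
#   q_idx 2 [ True,  True,  True, False]
#   q_idx 3 [ True,  True,  True,  True]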
def causal_mask(_, __, q_idx, kv_idx):
    return q_idx >= kv_idx


# Kernel profiling for backward mode does not work as expected.
# For details see: https://github.com/pytorch/pytorch/issues/144778
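# The sweep below covers three shape families, each run in the mode selected by the
# FA_KERNEL_MODE environment variable ('fwd' by default, 'bwd' for backward):
#   * batch sizes 1-32 at a fixed token budget (Z * N_CTX = 16384) with (H, D_HEAD) in
#     {(16, 128), (32, 64)},
#   * a single (Z=4, H=48, N_CTX=1024, D_HEAD=64) configuration,
#   * batch sizes 1-64 at N_CTX = 1024 with (H, D_HEAD) in {(8, 128), (32, 96), (4, 128)}.
# All configurations use a causal mask.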
@benchmark_suit.perf_report(
    benchmark_suit.Benchmark(
        x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'],
        x_vals=[[z, h, 16384 // z, dhead, causal, mode]
                for z in [1, 2, 4, 8, 16, 32]
                for (h, dhead) in [(16, 128), (32, 64)]
                for causal in [True]
                for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]  #
        + [[4, 48, 1024, 64, True, mode] for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]  #
        + [[z, h, 1024, dhead, True, mode]
           for z in [1, 2, 4, 8, 16, 32, 64]
           for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
           for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]],
        line_arg='provider',
        line_vals=['triton', 'xetla'],
        line_names=['Triton', 'XeTLA'],
        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
        ylabel=['GB/s', 'TFlops'],
        plot_name='flexAttnCausal-performance',
        args={},
    ))
def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
    assert MODE in ['fwd', 'bwd']
    assert CAUSAL
    dtype = torch.float16
    q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
    k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
    v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
    sm_scale = 0.125
    if MODE == 'bwd':
        sm_scale = 1.3

    # Quantiles reported by do_bench: median (0.5), minimum (0.0) and maximum (1.0).
    quantiles = [0.5, 0.0, 1.0]
    if provider == 'triton':
        block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX, N_CTX, device=q.device)
        triton_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale)
        if MODE == 'bwd':
            triton_o = triton_fn()
            triton_do = torch.randn_like(triton_o)
            triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
        # Eager scaled_dot_product_attention on the CPU serves as the accuracy reference.
        torch_fn = lambda: F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), is_causal=True, scale=sm_scale).to(
            torch.float32)
        if MODE == 'bwd':
            torch_o = torch_fn()
            torch_do = torch.randn_like(torch_o)
            torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
        if MODE == 'fwd':
            atol = 1e-1 if N_CTX == 16384 else 1e-2
            benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
        else:
            benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)

    elif provider == 'xetla':
        xetla_fn = None
        if MODE == 'fwd':
            module_name = 'flash_attn_causal_True'.lower()
            func = getattr(xetla_kernel, module_name)
            out = torch.empty_like(q, device='xpu', dtype=dtype)
            size_score = Z * H * N_CTX * N_CTX
            size_attn_mask = Z * N_CTX * N_CTX
            dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
            bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
            size_ml = Z * H * N_CTX
            m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
            l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
            xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
        if MODE == 'bwd':
            module_name = 'flash_attn_bwd_causal_True'.lower()
            func = getattr(xetla_kernel, module_name)
            grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
            bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
            dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8)
            out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
            log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
            workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
            grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
            alpha = sm_scale
            dropout_prob = 0
            grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
            grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True)
            grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True)
            grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True)
            bias_strideB = -1
            bias_strideN = -1
            bias_strideF = -1
            attn_mask_padding = 0

            xetla_fn = lambda: func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha,
                                    dropout_prob, grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX,
                                    N_CTX, bias_strideB, bias_strideN, bias_strideF, attn_mask_padding)
        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)

    else:
        raise NotImplementedError(f'Unsupported provider {provider}')

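    # Reported throughput is derived from the mean/min/max kernel time:
    #   * TFLOPS: attention performs two N_CTX x N_CTX x D_HEAD matmuls (Q @ K^T and P @ V)
    #     per batch and head, at 2 FLOPs per multiply-accumulate, i.e.
    #     2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD operations (the causal structure is not
    #     discounted).
    #   * GB/s: reading/writing the four N_CTX x D_HEAD tensors (Q, K, V and the output) per
    #     batch and head in float16, i.e. 2 bytes per element.
    # The backward pass is credited with 2.5x the forward work, a common flash-attention
    # accounting convention.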
    tflops = lambda mean: 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
    gbps = lambda mean: Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)

    if MODE == 'bwd':
        tflops = lambda mean: 2.5 * 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
        gbps = lambda mean: 2.5 * Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)

    return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv


if __name__ == '__main__':
    benchmark.run(show_plots=False, print_data=True)
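# Example invocations (the file name below is illustrative; adjust it to wherever this
# script lives in the benchmark suite):
#   python flex_attention_benchmark_causal_mask.py                     # forward pass (default)
#   FA_KERNEL_MODE=bwd python flex_attention_benchmark_causal_mask.py  # backward pass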