diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index 2149a8e..7c34ba2 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -283,11 +283,9 @@ def dynamic_mask_attention_triton( attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv) # Ensure correct data types and memory layout for Triton function - query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] - key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] - value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] - attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] - attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + query_states = query_states.transpose(1, 2) # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2) # [batch, key_len, num_heads, head_dim] + value_states = value_states.transpose(1, 2) # [batch, key_len, num_heads, head_dim] # Call the Triton implementation attn_outputs = triton_dmattn_func( @@ -729,6 +727,239 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): return all_passed +def test_triton_backward_equivalence(accuracy_threshold=0.95): + """Test backward pass equivalence between Python prototype and Triton implementation.""" + print("\n" + "๐Ÿš€" + "=" * 76 + "๐Ÿš€") + print("๐Ÿ”ฌ Testing backward Pass Equivalence: Python Prototype vs Triton Implementation") + print("๐Ÿš€" + "=" * 76 + "๐Ÿš€") + + # Check if Triton implementation is available + if triton_dmattn_func is None: + print("โŒ Triton implementation not available, skipping test.") + return False + + # Set random seed for reproducibility + torch.manual_seed(0) + + # Test different parameter configurations + # If you encounter NAN issues when running multiple configurations, try running a single configuration + # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) + test_configs = [ + # Head dim 32 + (1, 2, 1, 128, 128, 32, False), + (1, 2, 1, 128, 128, 32, True), + (1, 2, 1, 256, 256, 32, False), + (1, 2, 1, 256, 256, 32, True), + (1, 2, 1, 512, 512, 32, False), + (1, 2, 1, 512, 512, 32, True), + (1, 2, 1, 1024, 1024, 32, False), + (1, 2, 1, 1024, 1024, 32, True), + (1, 2, 1, 2048, 2048, 32, False), + (1, 2, 1, 2048, 2048, 32, True), + (1, 2, 1, 4096, 4096, 32, False), + (1, 2, 1, 4096, 4096, 32, True), + + # Head dim 64 + (1, 2, 1, 128, 128, 64, False), + (1, 2, 1, 128, 128, 64, True), + (1, 2, 1, 256, 256, 64, False), + (1, 2, 1, 256, 256, 64, True), + (1, 2, 1, 512, 512, 64, False), + (1, 2, 1, 512, 512, 64, True), + (1, 2, 1, 1024, 1024, 64, False), + (1, 2, 1, 1024, 1024, 64, True), + (1, 2, 1, 2048, 2048, 64, False), + (1, 2, 1, 2048, 2048, 64, True), + (1, 2, 1, 4096, 4096, 64, False), + (1, 2, 1, 4096, 4096, 64, True), + + # Head dim 96 + (1, 2, 1, 128, 128, 96, False), + (1, 2, 1, 128, 128, 96, True), + (1, 2, 1, 256, 256, 96, False), + (1, 2, 1, 256, 256, 96, True), + (1, 2, 1, 512, 512, 96, False), + (1, 2, 1, 512, 512, 96, True), + (1, 2, 1, 1024, 1024, 96, False), + (1, 2, 1, 1024, 1024, 96, True), + (1, 2, 1, 2048, 2048, 96, False), + (1, 2, 1, 2048, 2048, 96, True), + (1, 2, 1, 4096, 4096, 96, False), + (1, 2, 1, 4096, 4096, 96, True), + + # Head dim 128 + (1, 2, 1, 128, 128, 128, False), + (1, 2, 1, 128, 128, 128, True), + (1, 2, 1, 256, 256, 128, False), + (1, 2, 1, 256, 256, 128, True), + (1, 2, 1, 512, 512, 128, False), + (1, 
2, 1, 512, 512, 128, True), + (1, 2, 1, 1024, 1024, 128, False), + (1, 2, 1, 1024, 1024, 128, True), + (1, 2, 1, 2048, 2048, 128, False), + (1, 2, 1, 2048, 2048, 128, True), + (1, 2, 1, 4096, 4096, 128, False), + (1, 2, 1, 4096, 4096, 128, True), + + # triton currently supports up to head dim 128 + ] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.bfloat16 + device_icon = "๐Ÿ”ฅ" if device.type == "cuda" else "๐Ÿ’ป" + print(f"{device_icon} Using device: {device}") + + all_passed = True + + for i, config in enumerate(test_configs): + torch.cuda.empty_cache() + gc.collect() + torch.cuda.synchronize() + + batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal = config + + # Progress indicator + progress_filled = "โ–ˆ" * (i + 1) + progress_empty = "โ–‘" * (len(test_configs) - i - 1) + progress_bar = f"[{progress_filled}{progress_empty}]" + + print(f"\n๐Ÿงช Test configuration {i+1}/{len(test_configs)} {progress_bar}") + print(f" ๐Ÿ“Š batch_size={batch_size}, num_heads={num_heads}, num_kv_heads={num_kv_heads}") + print(f" ๐Ÿ“ query_len={query_len}, key_len={key_len}, head_dim={head_dim}") + print(f" ๐Ÿ”’ is_causal={is_causal}") + print(f" ๐ŸŽฏ Accuracy threshold: {accuracy_threshold*100:.1f}%") + + # Create random input data + query_states = torch.randn( + batch_size, num_heads, query_len, head_dim, + device=device, dtype=dtype, requires_grad=True + ) + key_states = torch.randn( + batch_size, num_kv_heads, key_len, head_dim, + device=device, dtype=dtype, requires_grad=True + ) + value_states = torch.randn( + batch_size, num_kv_heads, key_len, head_dim, + device=device, dtype=dtype, requires_grad=True + ) + attn_bias = torch.randn( + batch_size, num_kv_heads, query_len, key_len, + device=device, dtype=torch.bfloat16 + ) + cache_position = torch.arange(key_len - query_len, key_len, device=device) + causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + + # Set scaling factor and keep window size + scaling = head_dim ** -0.5 + window_size = 10240 + + # Clone inputs for Python implementation + query_python = query_states.clone().detach().requires_grad_(True) + key_python = key_states.clone().detach().requires_grad_(True) + value_python = value_states.clone().detach().requires_grad_(True) + attn_bias_python = attn_bias.clone().detach().requires_grad_(True) + causal_mask_python = causal_mask.clone().detach() + + # Run Python implementation + start_time = time.time() + attn_outputs_python, dq_python, dk_python, dv_python, dbias_python = dynamic_mask_attention_python( + query_python, key_python, value_python, + attn_bias_python, causal_mask_python, + scaling, window_size, is_causal + ) + torch.cuda.synchronize() + py_time = time.time() - start_time + + # Clone inputs for Triton implementation + query_triton = query_states.clone().detach().requires_grad_(True) + key_triton = key_states.clone().detach().requires_grad_(True) + value_triton = value_states.clone().detach().requires_grad_(True) + attn_bias_triton = attn_bias.clone().detach().requires_grad_(True) + causal_mask_triton = causal_mask.clone().detach() + + # Run Triton implementation + start_time = time.time() + attn_outputs_triton, dq_triton, dk_triton, dv_triton, dbias_triton = dynamic_mask_attention_triton( + query_triton, key_triton, value_triton, + attn_bias_triton, causal_mask_triton, + scaling, window_size, is_causal + ) + torch.cuda.synchronize() + triton_time = 
time.time() - start_time + + # Analyze outputs + print(f"\n๐Ÿ” Analyzing differences between Python and Triton outputs:") + is_attn_output_close, max_attn_output_diff, mean_attn_output_diff = analyze_differences( + attn_outputs_python, attn_outputs_triton, accuracy_threshold + ) + + # Analyze dQ gradients + print(f"\n๐Ÿ” Analyzing dQ gradients:") + is_dq_close, max_dq_diff, mean_dq_diff = analyze_differences( + dq_python, dq_triton, accuracy_threshold + ) + + # Analyze dK gradients + print(f"\n๐Ÿ” Analyzing dK gradients:") + is_dk_close, max_dk_diff, mean_dk_diff = analyze_differences( + dk_python, dk_triton, accuracy_threshold + ) + + # Analyze dV gradients + print(f"\n๐Ÿ” Analyzing dV gradients:") + is_dv_close, max_dv_diff, mean_dv_diff = analyze_differences( + dv_python, dv_triton, accuracy_threshold + ) + + # Analyze dBias gradients + print(f"\n๐Ÿ” Analyzing dBias gradients:") + is_dbias_close, max_dbias_diff, mean_dbias_diff = analyze_differences( + dbias_python, dbias_triton, accuracy_threshold + ) + + # Report performance difference + speedup = py_time / triton_time if triton_time > 0 else float('inf') + print(f"\nโšก Performance comparison:") + print(f" ๐Ÿ Python implementation: {py_time*1000:.2f} ms") + print(f" ๐Ÿš€ Triton implementation: {triton_time*1000:.2f} ms") + print(f" ๐Ÿ“ˆ Speedup: {speedup:.2f}x") + + # Check if all gradients pass + is_close = (is_attn_output_close and is_dq_close and is_dk_close and is_dv_close and is_dbias_close) + test_result = "Passed" if is_close else "Failed" + result_icon = "โœ…" if is_close else "โŒ" + all_passed = all_passed and is_close + print(f"\n{result_icon} Test result: {test_result}") + + # If test fails with large difference, can exit early + if not is_close and max_attn_output_diff > 1e-2: + print(" โš ๏ธ Difference too large, stopping subsequent tests.") + break + if not is_close and max_dq_diff > 1e-2: + print(" โš ๏ธ Difference too large, stopping subsequent tests.") + break + if not is_close and max_dk_diff > 1e-2: + print(" โš ๏ธ Difference too large, stopping subsequent tests.") + break + if not is_close and max_dv_diff > 1e-2: + print(" โš ๏ธ Difference too large, stopping subsequent tests.") + break + if not is_close and max_dbias_diff > 1e-2: + print(" โš ๏ธ Difference too large, stopping subsequent tests.") + break + del query_states, key_states, value_states, attn_bias, causal_mask, cache_position, dq_python, dk_python, dv_python, dbias_python, dq_triton, dk_triton, dv_triton, dbias_triton + torch.cuda.empty_cache() + gc.collect() + torch.cuda.synchronize() + + print("\n" + "๐Ÿ" + "=" * 76 + "๐Ÿ") + summary_icon = "๐ŸŽ‰" if all_passed else "๐Ÿ˜ž" + print(f"{summary_icon} Backward Equivalence Test Summary: {'All Passed' if all_passed else 'Some Tests Failed'}") + print("๐Ÿ" + "=" * 76 + "๐Ÿ") + + return all_passed + def main(): """ Test backward pass equivalence between Python prototype and various implementations @@ -782,9 +1013,9 @@ def main(): print("\n" + "๐Ÿ“" + " Starting Python vs CUDA Backward Tests " + "๐Ÿ“") test_results['cuda'] = test_cuda_backward_equivalence(args.accuracy_threshold) - # if args.test_type in ['all', 'triton']: - # print("\n" + "๐Ÿ”ฅ" + " Starting Python vs Triton Backward Tests " + "๐Ÿ”ฅ") - # test_results['triton'] = test_triton_backward_equivalence(args.accuracy_threshold) + if args.test_type in ['all', 'triton']: + print("\n" + "๐Ÿ”ฅ" + " Starting Python vs Triton Backward Tests " + "๐Ÿ”ฅ") + test_results['triton'] = 
test_triton_backward_equivalence(args.accuracy_threshold) # if args.test_type in ['all', 'flex']: # print("\n" + "๐ŸŒŸ" + " Starting Python vs Flex Attention Backward Tests " + "๐ŸŒŸ") diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index c94500b..66141cb 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -1,3 +1,4 @@ +from typing import Optional import math import torch @@ -5,15 +6,41 @@ import triton.language as tl -# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128 -# @triton.autotune( -# configs=[ -# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1), -# # This config has a race condition when EVEN_M == False, disabling it for now. -# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), -# ], -# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'] -# ) +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_warps=8, + num_stages=1, + ), + ], + key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'HAS_MASK', 'HAS_BIAS', 'BLOCK_HEADDIM'] +) @triton.heuristics( { "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, @@ -30,7 +57,6 @@ def _fwd_kernel( Bias, Out, Lse, - TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug softmax_scale, stride_qb, stride_qh, @@ -51,13 +77,19 @@ def _fwd_kernel( stride_oh, stride_om, nheads, + nheads_k, + nheads_mask, + nheads_bias, + h_h_k_ratio, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, - CACHE_KEY_SEQLEN_Q, - CACHE_KEY_SEQLEN_K, + CACHE_KEY_SEQLEN_Q: tl.constexpr, + CACHE_KEY_SEQLEN_K: tl.constexpr, IS_CAUSAL: tl.constexpr, + HAS_MASK: tl.constexpr, + HAS_BIAS: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, @@ -68,43 +100,55 @@ def _fwd_kernel( start_m = tl.program_id(0) off_hb = tl.program_id(1) off_b = off_hb // nheads - off_h = off_hb % nheads + off_hq = off_hb % nheads + off_hk = off_hq // h_h_k_ratio + if HAS_MASK: + if nheads_mask == 1: + off_hmask = 0 + elif nheads_mask == nheads_k: + off_hmask = off_hk + else: + off_hmask = off_hq + if HAS_BIAS: + if nheads_bias == 1: + off_hbbias = 0 + elif nheads_bias == nheads_k: + off_hbbias = off_hk + else: + off_hbbias = off_hq # off_b = tl.program_id(1) # off_h = tl.program_id(2) # off_hb = off_b * nheads + off_h - # initialize offsets + + # Initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_HEADDIM) + # Initialize pointers to Q, K, V, Mask, Bias - # Adding parenthesis around indexing might use int32 math instead of int64 math? 
- # https://github.com/openai/triton/issues/741 - # I'm seeing a tiny bit of difference (5-7us) q_ptrs = ( - Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + Q + off_b * stride_qb + off_hq * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) ) k_ptrs = ( - K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) + K + off_b * stride_kb + off_hk * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) ) v_ptrs = ( - V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) + V + off_b * stride_vb + off_hk * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) ) m_ptrs = ( - Mask + off_b * stride_mb + off_h * stride_mh + (offs_m[:, None] * stride_mm + offs_n[None, :]) - ) + Mask + off_b * stride_mb + off_hmask * stride_mh + (offs_m[:, None] * stride_mm + offs_n[None, :]) + ) if HAS_MASK else None b_ptrs = ( - Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) - ) - - # initialize pointer to m and l - t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m - lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + Bias + off_b * stride_bb + off_hbbias * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) + ) if HAS_BIAS else None + + # Initialize pointer to m and l + lse_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) - # load q: it will stay in SRAM throughout - # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call - # tl.load(q_ptrs), we get the wrong output! 
- if EVEN_M & EVEN_N: + + # Load q: it will stay in SRAM throughout + if EVEN_M: if EVEN_HEADDIM: q = tl.load(q_ptrs) else: @@ -116,133 +160,134 @@ def _fwd_kernel( q = tl.load( q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 ) - # loop over k, v and update accumulator + + # Scale q + q = (q * softmax_scale).to(q.dtype) + + # Loop over k, v and update accumulator end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) for start_n in range(0, end_n, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - # Load k - if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition - if EVEN_HEADDIM: - k = tl.load(k_ptrs + start_n * stride_kn) - else: - k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - k = tl.load( - k_ptrs + start_n * stride_kn, - mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0, - ) + if HAS_MASK: + # Load mask + if EVEN_M & EVEN_N: + mask = tl.load(m_ptrs + start_n) else: - k = tl.load( - k_ptrs + start_n * stride_kn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0, + mask = tl.load( + m_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), + other=False ) - # compute acc_s - acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - acc_s += tl.dot(q, tl.trans(k)) - - # Trying to combine the two masks seem to make the result wrong - # Apply sequence length mask - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - acc_s += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) - # Apply causal mask - if IS_CAUSAL: - acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) - - # Load mask - if EVEN_M & EVEN_N: - mask = tl.load(m_ptrs + start_n) - else: - mask = tl.load( - m_ptrs + start_n, - mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), - other=0.0 - ) - - # Check if any element in mask is non-zero - # BUG: Triton needs to determine the control flow at compile time. - # Dynamic conditions at runtime can undermine this optimization. - # any_active = tl.sum(mask) != 0 - # Apply dynamic mask - acc_s += tl.where(mask > 0.0, 0.0, float("-inf")) - - # Load bias - if EVEN_M & EVEN_N: - bias = tl.load(b_ptrs + start_n).to(tl.float32) + # Check if any element in mask is non-zero + any_active = tl.reduce_or(mask, axis=None) else: - bias = tl.load( - b_ptrs + start_n, - mask=(offs_m[:, None] < seqlen_q) - & ((start_n + offs_n)[None, :] < seqlen_k), - other=0.0, - ).to(tl.float32) - - # Apply scaling and bias - # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler - # can then fuse the mult and add into an fma instruction. But if we have bias we need to - # to multiply with softmax_scale here. 
- acc_s = acc_s * softmax_scale + bias - # acc_s = tl.where(acc_s != float("-inf"), acc_s * softmax_scale + bias, acc_s) - m_ij = tl.maximum(tl.max(acc_s, 1), lse_i) - p = tl.exp(acc_s - m_ij[:, None]) - l_ij = tl.sum(p, 1) - - # scale acc_o - acc_o_scale = tl.exp(m_i - m_ij) - - # update output accumulator - # BUG: have to store and immediately load - tl.store(t_ptrs, acc_o_scale) - acc_o_scale = tl.load(t_ptrs) - acc_o = acc_o * acc_o_scale[:, None] - - # load v - if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition - if EVEN_HEADDIM: - v = tl.load(v_ptrs + start_n * stride_vn) + any_active = True + + # Skip this iteration if no active elements + if any_active: + + # Load k + if EVEN_N: + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) else: - v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - v = tl.load( - v_ptrs + start_n * stride_vn, - mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0, - ) + if EVEN_HEADDIM: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + + if HAS_BIAS: + # Load bias + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load( + b_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) + & ((start_n + offs_n)[None, :] < seqlen_k), + other=0.0, + ).to(tl.float32) + acc_s = bias else: - v = tl.load( - v_ptrs + start_n * stride_vn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0, - ) - acc_o += tl.dot(p.to(v.dtype), v) + acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + + # Compute acc_s + acc_s += tl.dot(q, tl.trans(k)) - # update statistics - m_i = m_ij - l_i_new = tl.exp(lse_i - m_ij) + l_ij - lse_i = m_ij + tl.log(l_i_new) + # Apply masks + # Trying to combine the three masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + acc_s += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + if HAS_MASK: + acc_s += tl.where(mask, 0, float("-inf")) + + # Compute p + m_ij = tl.maximum(tl.max(acc_s, 1), lse_i) + p = tl.exp(acc_s - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + # Scale acc_o + acc_o_scale = tl.exp(m_i - m_ij) + + # Update output accumulator + acc_o = acc_o * acc_o_scale[:, None] + + # Load v + if EVEN_N: + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + + # Compute acc_o + acc_o += tl.dot(p.to(v.dtype), v) + + # Update statistics + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) o_scale = tl.exp(m_i - lse_i) - # BUG: have to store and immediately load - tl.store(t_ptrs, o_scale) - o_scale = tl.load(t_ptrs) acc_o = acc_o * o_scale[:, None] - # rematerialize offsets to save 
registers + # Rematerialize offsets to save registers start_m = tl.program_id(0) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # write back l and m + # Write back l and m lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m tl.store(lse_ptrs, lse_i) - # initialize pointers to output + # Initialize pointers to output offs_d = tl.arange(0, BLOCK_HEADDIM) out_ptrs = ( Out + off_b * stride_ob - + off_h * stride_oh + + off_hq * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :]) ) if EVEN_M: @@ -281,10 +326,10 @@ def _bwd_preprocess_do_o_dot( off_hb = tl.program_id(1) off_b = off_hb // nheads off_h = off_hb % nheads - # initialize offsets + # Initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) - # load + # Load o o = tl.load( Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), @@ -300,42 +345,10 @@ def _bwd_preprocess_do_o_dot( other=0.0, ).to(tl.float32) delta = tl.sum(o * do, axis=1) - # write-back + # Write back tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) -@triton.jit -def _bwd_store_dk_dv( - dk_ptrs, - dv_ptrs, - dk, - dv, - offs_n, - offs_d, - seqlen_k, - headdim, - EVEN_M: tl.constexpr, - EVEN_N: tl.constexpr, - EVEN_HEADDIM: tl.constexpr, -): - # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False, - # if we just call tl.store(dv_ptrs), there's a race condition - if EVEN_N & EVEN_M: - if EVEN_HEADDIM: - tl.store(dv_ptrs, dv) - tl.store(dk_ptrs, dk) - else: - tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) - tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) - else: - if EVEN_HEADDIM: - tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) - tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) - else: - tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - - @triton.jit def _bwd_kernel_one_col_block( start_n, @@ -365,32 +378,37 @@ def _bwd_kernel_one_col_block( seqlen_q, seqlen_k, headdim, - ATOMIC_ADD: tl.constexpr, IS_CAUSAL: tl.constexpr, + HAS_MASK: tl.constexpr, + HAS_BIAS: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + ATOMIC_ADD: tl.constexpr, + ACCUM_DBIAS: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M - # initialize row/col offsets + # Initialize row/col offsets offs_qm = begin_m + tl.arange(0, BLOCK_M) offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) - # initialize pointers to value-like data + # Initialize pointers to value-like data q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) - m_ptrs = Mask + (offs_qm[:, None] * stride_mm + offs_n[None, :]) - b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) + if HAS_MASK: + m_ptrs = Mask + (offs_qm[:, None] * stride_mm + offs_n[None, :]) + if HAS_BIAS: + b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) 
db_ptrs = DBias + (offs_qm[:, None] * stride_dbm + offs_n[None, :]) - # initialize dv and dk + # Initialize dv and dk dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) # There seems to be some problem with Triton pipelining that makes results wrong for @@ -399,24 +417,25 @@ def _bwd_kernel_one_col_block( if begin_m >= seqlen_q: dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) - _bwd_store_dk_dv( - dk_ptrs, - dv_ptrs, - dk, - dv, - offs_n, - offs_d, - seqlen_k, - headdim, - EVEN_M=EVEN_M, - EVEN_N=EVEN_N, - EVEN_HEADDIM=EVEN_HEADDIM, - ) + + if EVEN_N: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) return - # k and v stay in SRAM throughout - # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, - # if we just call tl.load(k_ptrs), we get the wrong output! - if EVEN_N & EVEN_M: + + # Load k and v; they will stay in SRAM throughout + if EVEN_N: if EVEN_HEADDIM: k = tl.load(k_ptrs) v = tl.load(v_ptrs) @@ -434,218 +453,213 @@ def _bwd_kernel_one_col_block( v = tl.load( v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 ) - # loop over rows + + # Scale k + k = (k * softmax_scale).to(k.dtype) + + # Initialize accumulator for dbias if needed + acc_dbias = tl.zeros([BLOCK_N], dtype=tl.float32) if (HAS_BIAS and ACCUM_DBIAS) else None + + # Loop over q and update accumulators num_block_m = tl.cdiv(seqlen_q, BLOCK_M) for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): start_m = tl.multiple_of(start_m, BLOCK_M) offs_m_curr = start_m + offs_m - # load q, k, v, do on-chip - # Same bug as below. 
Otherwise gives wrong result for headdim=40, seqlen=(128, 117) - if EVEN_M & EVEN_HEADDIM: - q = tl.load(q_ptrs) - else: - if EVEN_HEADDIM: - q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) - else: - q = tl.load( - q_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - ) - # recompute p = softmax(acc_s, dim=-1).T - acc_s = tl.dot(q, tl.trans(k)) - - tl.debug_barrier() - # Load mask - if EVEN_M & EVEN_N: - mask = tl.load(m_ptrs) - else: - mask = tl.load( - m_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), - other=0.0, - ) - - # Trying to combine the two masks seem to make the result wrong - # Apply sequence length mask - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - acc_s = tl.where(offs_n[None, :] < seqlen_k, acc_s, float("-inf")) - # Apply causal mask - if IS_CAUSAL: - acc_s = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), acc_s, float("-inf")) - # Apply dynamic mask - acc_s = tl.where(mask > 0.0, acc_s, float("-inf")) - - tl.debug_barrier() # Race condition otherwise - # Load bias - if EVEN_M & EVEN_N: - bias = tl.load( - b_ptrs, - mask=(mask > 0.0), - other=0.0, - ).to(tl.float32) - else: - bias = tl.load( - b_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k) & (mask > 0.0), - other=0.0, - ).to(tl.float32) - acc_s = acc_s * softmax_scale + bias - # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. - # Also wrong for headdim=64. - if not (EVEN_M & EVEN_HEADDIM): - tl.debug_barrier() - lse_i = tl.load(LSE + offs_m_curr) - p = tl.exp(acc_s - lse_i[:, None]) - # compute dv - # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call - # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs - # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, - # the output is correct. - if EVEN_M & EVEN_HEADDIM: - do = tl.load(do_ptrs) - else: - # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask. - do = tl.load( - do_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - ) - # if EVEN_M: - # if EVEN_HEADDIM: - # do = tl.load(do_ptrs) - # else: - # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - # else: - # if EVEN_HEADDIM: - # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) - # else: - # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) - # & (offs_d[None, :] < headdim), other=0.0) - dv += tl.dot(tl.trans(p.to(do.dtype)), do) - # compute dp = dot(v, do) - # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. 
- # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True - # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False - if not (EVEN_M & EVEN_HEADDIM): - tl.debug_barrier() - dp = tl.dot(do, tl.trans(v)) - # There's a race condition for headdim=48 - if not EVEN_HEADDIM: - tl.debug_barrier() - # compute dbias = p * (dp - delta[:, None]) and ds = dbias * softmax_scale - # Putting the subtraction after the dp matmul (instead of before) is slightly faster - Di = tl.load(D + offs_m_curr) - # Converting ds to q.dtype here reduces register pressure and makes it much faster - # for BLOCK_HEADDIM=128 - dbias = (p * (dp - Di[:, None])) - ds = (dbias * softmax_scale).to(q.dtype) - # dbias = tl.where(mask > 0.0, dbias, 0.0) - # ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) - if not (EVEN_M & EVEN_N): - tl.debug_barrier() - if not ATOMIC_ADD: + + if HAS_MASK: + # Load mask if EVEN_M & EVEN_N: - tl.store( - db_ptrs, - dbias - ) + mask = tl.load(m_ptrs) else: - tl.store( - db_ptrs, - dbias, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k) + mask = tl.load( + m_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), + other=False, ) + + # Check if any element in mask is non-zero + any_active = tl.reduce_or(mask, axis=None) else: - if EVEN_M & EVEN_N: - tl.atomic_add( - db_ptrs, - dbias - ) - else: - tl.atomic_add( - db_ptrs, - dbias, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k) - ) - # compute dk = dot(ds.T, q) - dk += tl.dot(tl.trans(ds), q) - # compute dq - if not ( - EVEN_M & EVEN_HEADDIM - ): # Otherewise there's a race condition - tl.debug_barrier() - if not ATOMIC_ADD: - if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M - dq = tl.load(dq_ptrs, eviction_policy="evict_last") - dq += tl.dot(ds, k) - tl.store(dq_ptrs, dq, eviction_policy="evict_last") + any_active = True + + # Skip this iteration if no active elements + if any_active: + # Load q + if EVEN_M & EVEN_HEADDIM: + q = tl.load(q_ptrs) else: if EVEN_HEADDIM: - dq = tl.load( - dq_ptrs, - mask=offs_m_curr[:, None] < seqlen_q, - other=0.0, - eviction_policy="evict_last", - ) - dq += tl.dot(ds, k) - tl.store( - dq_ptrs, - dq, - mask=offs_m_curr[:, None] < seqlen_q, - eviction_policy="evict_last", - ) + q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) else: - dq = tl.load( - dq_ptrs, + q = tl.load( + q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, - eviction_policy="evict_last", - ) - dq += tl.dot(ds, k) - tl.store( - dq_ptrs, - dq, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - eviction_policy="evict_last", ) - else: # If we're parallelizing across the seqlen_k dimension - dq = tl.dot(ds, k) - if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M - tl.atomic_add(dq_ptrs, dq) + + if HAS_BIAS: + # Load bias + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load( + b_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), + other=0.0, + ).to(tl.float32) + acc_s = bias else: - if EVEN_HEADDIM: - tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) + acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + + # Compute acc_s + acc_s += tl.dot(q, tl.trans(k)) + + # Apply masks + # Trying to combine the three masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + acc_s += tl.where(offs_n[None, :] < seqlen_k, 0, 
float("-inf")) + if IS_CAUSAL: + acc_s += tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), 0, float("-inf")) + if HAS_MASK: + acc_s += tl.where(mask, 0, float("-inf")) + + lse_i = tl.load(LSE + offs_m_curr) + # p = tl.exp(acc_s - lse_i[:, None]) + p = tl.exp(acc_s - tl.where(lse_i > float("-inf"), lse_i, 0.0)[:, None]) + + # Load do + if EVEN_M & EVEN_HEADDIM: + do = tl.load(do_ptrs) + else: + # There's a race condition if we just use m_mask and not d_mask. + do = tl.load( + do_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + + # Compute dv + dv += tl.dot(tl.trans(p.to(do.dtype)), do) + + # Compute dp + dp = tl.dot(do, tl.trans(v)) + + # Putting the subtraction after the dp matmul (instead of before) is slightly faster + Di = tl.load(D + offs_m_curr) + + # Compute ds + # Converting ds to q.dtype here reduces register pressure and makes it much faster + # for BLOCK_HEADDIM=128 + ds = (p * (dp - Di[:, None])).to(q.dtype) + + # Write back + if not (EVEN_M & EVEN_N): + tl.debug_barrier() + if HAS_BIAS: + if ACCUM_DBIAS: + acc_dbias += tl.sum(ds, axis=0) else: - tl.atomic_add( - dq_ptrs, - dq, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - ) - # increment pointers - do_ptrs += BLOCK_M * stride_dom - dq_ptrs += BLOCK_M * stride_dqm - db_ptrs += BLOCK_M * stride_dbm - q_ptrs += BLOCK_M * stride_qm - m_ptrs += BLOCK_M * stride_mm - b_ptrs += BLOCK_M * stride_bm - - # write-back + if EVEN_M & EVEN_N: + tl.store( + db_ptrs, + ds, + ) + else: + tl.store( + db_ptrs, + ds, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), + ) + + # Compute dk + dk += tl.dot(tl.trans(ds), q) + + # Compute dq + if not ATOMIC_ADD: + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds, k).to(ds.dtype) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + else: + if EVEN_HEADDIM: + dq = tl.load( + dq_ptrs, + mask=offs_m_curr[:, None] < seqlen_q, + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k).to(ds.dtype) + tl.store( + dq_ptrs, + dq, + mask=offs_m_curr[:, None] < seqlen_q, + eviction_policy="evict_last", + ) + else: + dq = tl.load( + dq_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k).to(ds.dtype) + tl.store( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + eviction_policy="evict_last", + ) + else: # If we're parallelizing across the seqlen_k dimension + dq = tl.dot(ds, k).to(ds.dtype) + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + tl.atomic_add(dq_ptrs, dq) + else: + if EVEN_HEADDIM: + tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) + else: + tl.atomic_add( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + ) + + # Increment pointers + do_ptrs += BLOCK_M * stride_dom + dq_ptrs += BLOCK_M * stride_dqm + if HAS_BIAS: + db_ptrs += BLOCK_M * stride_dbm + q_ptrs += BLOCK_M * stride_qm + if HAS_MASK: + m_ptrs += BLOCK_M * stride_mm + if HAS_BIAS: + b_ptrs += BLOCK_M * stride_bm + + # Scale dk + dk = (dk * softmax_scale).to(dk.dtype) + + # Write back dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) - _bwd_store_dk_dv( - dk_ptrs, - dv_ptrs, - dk, - dv, - offs_n, - offs_d, - seqlen_k, - headdim, - EVEN_M=EVEN_M, - EVEN_N=EVEN_N, - 
EVEN_HEADDIM=EVEN_HEADDIM, - ) + if HAS_BIAS and ACCUM_DBIAS: + if EVEN_N: + tl.store(DBias + offs_n, acc_dbias) + else: + tl.store(DBias + offs_n, acc_dbias, mask=(offs_n < seqlen_k)) + + if EVEN_N: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) def init_to_zero(names): @@ -665,26 +679,21 @@ def init_func(nargs): num_stages=1, pre_hook=init_to_zero(["DQ", "DBias"]), ), - # triton.Config( - # {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, - # num_warps=8, - # num_stages=1, - # pre_hook=init_to_zero(["DQ", "DBias"]), - # ), - # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now - # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* - # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])), - # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])), - # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])), - # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero(["DQ", "DBias"]), + ), ], - key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "IS_CAUSAL", "BLOCK_HEADDIM"], + key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "IS_CAUSAL", "HAS_MASK", "HAS_BIAS", "HAS_INDICE", "BLOCK_HEADDIM"], ) @triton.heuristics( { "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + "ACCUM_DBIAS": lambda args: args["HAS_BIAS"] and (args["stride_dbm"] == 0) and (args["seqlen_q"] > 1), } ) @triton.jit @@ -733,6 +742,10 @@ def _bwd_kernel( stride_dbh, stride_dbm, nheads, + nheads_k, + nheads_mask, + nheads_bias, + h_h_k_ratio, seqlen_q, seqlen_k, seqlen_q_rounded, @@ -740,31 +753,54 @@ def _bwd_kernel( CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, IS_CAUSAL: tl.constexpr, + HAS_MASK: tl.constexpr, + HAS_BIAS: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + ACCUM_DBIAS: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): off_hb = tl.program_id(1) off_b = off_hb // nheads - off_h = off_hb % nheads - # offset pointers for batch/head - Q += off_b * stride_qb + off_h * stride_qh - K += off_b * stride_kb + off_h * stride_kh - V += off_b * stride_vb + off_h * stride_vh - Mask += off_b * stride_mb + off_h * stride_mh - Bias += off_b * stride_bb + off_h * stride_bh - DO += off_b * stride_dob + off_h * stride_doh - DQ += off_b * stride_dqb + off_h * stride_dqh - DK += off_b * stride_dkb + off_h * stride_dkh - DV += off_b * stride_dvb + off_h * stride_dvh - 
DBias += off_b * stride_dbb + off_h * stride_dbh - # pointer to row-wise quantities in value-like data + off_hq = off_hb % nheads + off_hk = off_hq // h_h_k_ratio + if HAS_MASK: + if nheads_mask == 1: + off_hmask = 0 + elif nheads_mask == nheads_k: + off_hmask = off_hk + else: + off_hmask = off_hq + if HAS_BIAS: + if nheads_bias == 1: + off_hbbias = 0 + elif nheads_bias == nheads_k: + off_hbbias = off_hk + else: + off_hbbias = off_hq + + # Advance offset pointers for batch and head + Q += off_b * stride_qb + off_hq * stride_qh + K += off_b * stride_kb + off_hk * stride_kh + V += off_b * stride_vb + off_hk * stride_vh + if HAS_MASK: + Mask += off_b * stride_mb + off_hmask * stride_mh + if HAS_BIAS: + Bias += off_b * stride_bb + off_hbbias * stride_bh + DO += off_b * stride_dob + off_hq * stride_doh + DQ += off_b * stride_dqb + off_hq * stride_dqh + DK += off_b * stride_dkb + off_hq * stride_dkh + DV += off_b * stride_dvb + off_hq * stride_dvh + if HAS_BIAS: + DBias += off_b * stride_dbb + off_hq * stride_dbh + # Advance pointer to row-wise quantities in value-like data D += off_hb * seqlen_q_rounded LSE += off_hb * seqlen_q_rounded + if not SEQUENCE_PARALLEL: num_block_n = tl.cdiv(seqlen_k, BLOCK_N) for start_n in range(0, num_block_n): @@ -796,12 +832,15 @@ def _bwd_kernel( seqlen_q, seqlen_k, headdim, - ATOMIC_ADD=False, IS_CAUSAL=IS_CAUSAL, + HAS_MASK=HAS_MASK, + HAS_BIAS=HAS_BIAS, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, + ATOMIC_ADD=False, + ACCUM_DBIAS=ACCUM_DBIAS, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, ) @@ -835,52 +874,59 @@ def _bwd_kernel( seqlen_q, seqlen_k, headdim, - ATOMIC_ADD=True, IS_CAUSAL=IS_CAUSAL, + HAS_MASK=HAS_MASK, + HAS_BIAS=HAS_BIAS, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, + ATOMIC_ADD=True, + ACCUM_DBIAS=ACCUM_DBIAS, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, ) -def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False): +def _flash_dmattn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False): # shape constraints batch, seqlen_q, nheads, d = q.shape - _, seqlen_k, _, _ = k.shape - assert k.shape == (batch, seqlen_k, nheads, d) - assert v.shape == (batch, seqlen_k, nheads, d) - assert d <= 128, "FlashAttention only support head dimensions up to 128" + _, seqlen_k, nheads_k, _ = k.shape + + assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA" + assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128" assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" assert q.is_cuda and k.is_cuda and v.is_cuda - assert mask.shape == (batch, nheads, seqlen_q, seqlen_k), f"mask shape {mask.shape} does not match expected shape {(batch, nheads, seqlen_q, seqlen_k)}" - assert mask.dtype in [torch.float16, torch.bfloat16, torch.float32], "mask must be fp16, bf16, or fp32" - assert mask.is_cuda, "mask must be on CUDA" - if mask.stride(-1) != 1: - mask = mask.contiguous() + has_mask = mask is not None + if has_mask: + assert mask.dtype == torch.bool, "Only support bool" + assert mask.is_cuda + nheads_mask = mask.shape[1] + else: + nheads_mask = 1 + mask = torch.empty(0, device=q.device, dtype=torch.bool) - assert bias.dtype in [q.dtype, torch.float], f"bias dtype {bias.dtype} must match q dtype {q.dtype} or be float" - assert bias.is_cuda, "bias must be on CUDA" - assert bias.dim() == 4, f"bias must be 4D, got 
{bias.dim()}D" - assert bias.shape == (batch, nheads, seqlen_q, seqlen_k), f"bias shape {bias.shape} must be (batch={batch}, nheads={nheads}, seqlen_q={seqlen_q}, seqlen_k={seqlen_k})" - if bias.stride(-1) != 1: - bias = bias.contiguous() + has_bias = bias is not None + if has_bias: + assert bias.dtype == q.dtype, "Only support fp16 and bf16" + assert bias.is_cuda + nheads_bias = bias.shape[1] + else: + nheads_bias = 1 + bias = torch.empty(0, device=q.device, dtype=q.dtype) softmax_scale = softmax_scale or 1.0 / math.sqrt(d) seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) - tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) o = torch.empty_like(q) BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - BLOCK_M = 128 - BLOCK_N = 64 - num_warps = 4 if d <= 64 else 8 + # BLOCK_M = 128 + # BLOCK_N = 64 + # num_warps = 4 if d <= 64 else 8 grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) _fwd_kernel[grid]( q, @@ -890,7 +936,6 @@ def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False bias, o, lse, - tmp, softmax_scale, q.stride(0), q.stride(2), @@ -901,16 +946,20 @@ def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False v.stride(0), v.stride(2), v.stride(1), - mask.stride(0), - mask.stride(1), - mask.stride(2), - bias.stride(0), - bias.stride(1), - bias.stride(2), + ((0 if (has_mask and mask.shape[0] == 1) else (mask.stride(0) if has_mask else 0))), + ((0 if (has_mask and mask.shape[1] == 1) else (mask.stride(1) if has_mask else 0))), + ((0 if (has_mask and mask.shape[2] == 1) else (mask.stride(2) if has_mask else 0))), + ((0 if (has_bias and bias.shape[0] == 1) else (bias.stride(0) if has_bias else 0))), + ((0 if (has_bias and bias.shape[1] == 1) else (bias.stride(1) if has_bias else 0))), + ((0 if (has_bias and bias.shape[2] == 1) else (bias.stride(2) if has_bias else 0))), o.stride(0), o.stride(2), o.stride(1), nheads, + nheads_k, + nheads_mask, + nheads_bias, + nheads // nheads_k, seqlen_q, seqlen_k, seqlen_q_rounded, @@ -918,39 +967,49 @@ def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) # Can't use kwargs here because triton autotune expects key to be args, not kwargs - # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + # IS_CAUSAL=is_causal, HAS_MASK=has_mask, HAS_BIAS=has_bias, BLOCK_HEADDIM=d, is_causal, + has_mask, + has_bias, BLOCK_HEADDIM, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - num_warps=num_warps, - num_stages=1, + # BLOCK_M=BLOCK_M, + # BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, ) return o, lse, softmax_scale # softmax_scale could have been updated -def _flash_attn_backward( +def _flash_dmattn_backward( do, q, k, v, mask, bias, o, lse, softmax_scale=None, is_causal=False ): # Make sure that the last dimension is contiguous if do.stride(-1) != 1: do = do.contiguous() batch, seqlen_q, nheads, d = q.shape - _, seqlen_k, _, _ = k.shape - # assert d in {16, 32, 64, 128} - assert d <= 128 + _, seqlen_k, nheads_k, dk = k.shape + + assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA" + assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128" seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + seqlen_k_rounded = math.ceil(seqlen_k / 128) * 128 assert lse.shape == (batch, nheads, seqlen_q_rounded) 
- assert mask.dtype in [q.dtype, torch.float] - assert mask.is_cuda - assert mask.dim() == 4 - assert mask.stride(-1) == 1 + has_mask = mask is not None + if has_mask: + assert mask.dtype == torch.bool, "Only support bool" + nheads_mask = mask.shape[1] + else: + nheads_mask = 1 + mask = torch.empty(0, device=q.device, dtype=torch.bool) - assert bias.dtype in [q.dtype, torch.float] - assert bias.is_cuda - assert bias.dim() == 4 - assert bias.stride(-1) == 1 + has_bias = bias is not None + if has_bias: + assert bias.dtype == q.dtype, "Only support fp16 and bf16" + nheads_bias = bias.shape[1] + else: + nheads_bias = 1 + bias = torch.empty(0, device=q.device, dtype=q.dtype) softmax_scale = softmax_scale or 1.0 / math.sqrt(d) # dq_accum = torch.zeros_like(q, dtype=torch.float32) @@ -959,7 +1018,24 @@ def _flash_attn_backward( # delta = torch.zeros_like(lse) dk = torch.empty_like(k) dv = torch.empty_like(v) - dbias = torch.empty_like(bias) + dbias = torch.empty_like(bias) if has_bias else torch.empty(0, device=q.device, dtype=q.dtype) + + dk_expanded = torch.empty(batch, seqlen_k, nheads, d, device=q.device, dtype=q.dtype) if nheads != nheads_k else dk + dv_expanded = torch.empty(batch, seqlen_k, nheads, d, device=q.device, dtype=q.dtype) if nheads != nheads_k else dv + if has_bias: + if ( + nheads_bias != nheads + or ((bias.shape[0] == 1) and (batch > 1)) + or ((bias.shape[-2] == 1) and (seqlen_q > 1)) + ): + if bias.shape[-2] == 1: + dbias_expanded = torch.zeros(batch, nheads, 1, seqlen_k_rounded, device=q.device, dtype=dbias.dtype) + else: + dbias_expanded = torch.zeros(batch, nheads, seqlen_q, seqlen_k_rounded, device=q.device, dtype=dbias.dtype) + else: + dbias_expanded = dbias + else: + dbias_expanded = dbias BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) @@ -996,9 +1072,9 @@ def _flash_attn_backward( bias, do, dq_accum, - dk, - dv, - dbias, + dk_expanded, + dv_expanded, + dbias_expanded, lse, delta, softmax_scale, @@ -1011,28 +1087,32 @@ def _flash_attn_backward( v.stride(0), v.stride(2), v.stride(1), - mask.stride(0), - mask.stride(1), - mask.stride(2), - bias.stride(0), - bias.stride(1), - bias.stride(2), + ((0 if (has_mask and mask.shape[0] == 1) else (mask.stride(0) if has_mask else 0))), + ((0 if (has_mask and mask.shape[1] == 1) else (mask.stride(1) if has_mask else 0))), + ((0 if (has_mask and mask.shape[2] == 1) else (mask.stride(2) if has_mask else 0))), + ((0 if (has_bias and bias.shape[0] == 1) else (bias.stride(0) if has_bias else 0))), + ((0 if (has_bias and bias.shape[1] == 1) else (bias.stride(1) if has_bias else 0))), + ((0 if (has_bias and bias.shape[2] == 1) else (bias.stride(2) if has_bias else 0))), do.stride(0), do.stride(2), do.stride(1), dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1), - dk.stride(0), - dk.stride(2), - dk.stride(1), - dv.stride(0), - dv.stride(2), - dv.stride(1), - dbias.stride(0), - dbias.stride(1), - dbias.stride(2), + dk_expanded.stride(0), + dk_expanded.stride(2), + dk_expanded.stride(1), + dv_expanded.stride(0), + dv_expanded.stride(2), + dv_expanded.stride(1), + (dbias_expanded.stride(0) if has_bias else 0), + (dbias_expanded.stride(1) if has_bias else 0), + ((0 if (has_bias and bias.shape[-2] == 1) else (dbias_expanded.stride(2) if has_bias else 0))), nheads, + nheads_k, + nheads_mask, + nheads_bias, + nheads // nheads_k, seqlen_q, seqlen_k, seqlen_q_rounded, @@ -1040,16 +1120,45 @@ def _flash_attn_backward( seqlen_q // 32, seqlen_k // 32, # key for 
triton cache (limit number of compilations) # Can't use kwargs here because triton autotune expects key to be args, not kwargs - # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + # IS_CAUSAL=is_causal, HAS_MASK=has_mask, HAS_BIAS=has_bias, BLOCK_HEADDIM=BLOCK_HEADDIM, is_causal, + has_mask, + has_bias, BLOCK_HEADDIM, # SEQUENCE_PARALLEL=False, - # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + # BLOCK_M=BLOCK_M, + # BLOCK_N=BLOCK_N, # num_warps=num_warps, # num_stages=1, ) dq = dq_accum.to(q.dtype) - return dq, dk, dv, dbias + if nheads != nheads_k: + dk = dk_expanded.view(batch, seqlen_k, nheads_k, nheads // nheads_k, d).sum(dim=3) + dv = dv_expanded.view(batch, seqlen_k, nheads_k, nheads // nheads_k, d).sum(dim=3) + if has_bias: + if ( + nheads_bias != nheads + and bias.shape[0] == batch + and bias.shape[-2] == seqlen_q + ): + dbias = dbias_expanded.view(batch, nheads_bias, nheads // nheads_bias, seqlen_q, seqlen_k_rounded).sum(dim=2) + else: + if bias.shape[-2] == 1: + dbias_expanded = dbias_expanded.view(batch, nheads_bias, nheads // nheads_bias, 1, seqlen_k_rounded).sum(dim=2) + else: + dbias_expanded = dbias_expanded.view(batch, nheads_bias, nheads // nheads_bias, seqlen_q, seqlen_k_rounded).sum(dim=2) + if bias.shape[0] == 1: + dbias_expanded = dbias_expanded.sum(dim=0, keepdim=True) + dbias.copy_(dbias_expanded) + return dq, dk, dv, dbias if has_bias else None + + +def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def round_multiple(x, m): + return (x + m - 1) // m * m class FlashDMAttnFunc(torch.autograd.Function): @@ -1064,19 +1173,29 @@ def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=Fa is_causal: bool, whether to apply causal masking softmax_scale: float, scaling factor for attention scores """ - batch, seqlen_q, nheads, _ = query.shape - _, seqlen_k, _, _ = key.shape - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_mask = torch.where(attn_mask, 1.0, 0.0) - else: - attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype) - if attn_bias is None: - attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype) # Make sure that the last dimension is contiguous - query, key, value, attn_mask, attn_bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [query, key, value, attn_mask, attn_bias]] - o, lse, ctx.softmax_scale = _flash_attn_forward( + query, key, value, attn_mask, attn_bias = [maybe_contiguous(x) for x in [query, key, value, attn_mask, attn_bias]] + + # Padding to multiple of 8 for 16-bit memory allocations + head_size_og = query.size(3) + if head_size_og % 8 != 0: + query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8]) + key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8]) + value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8]) + seqlen_k_rounded = round_multiple(key.shape[1], 128) + if attn_mask is not None and attn_mask.shape[-1] != seqlen_k_rounded: + if attn_mask.shape[-1] == 1: + attn_mask = attn_mask.expand(*attn_mask.shape[:-1], seqlen_k_rounded) + else: + attn_mask = torch.nn.functional.pad(attn_mask, [0, seqlen_k_rounded - attn_mask.shape[-1]]) + if attn_bias is not None and attn_bias.shape[-1] != seqlen_k_rounded: + if attn_bias.shape[-1] == 1: + attn_bias = attn_bias.expand(*attn_bias.shape[:-1], seqlen_k_rounded) + else: + attn_bias = torch.nn.functional.pad(attn_bias, [0, seqlen_k_rounded - attn_bias.shape[-1]]) 
+ + o, lse, ctx.softmax_scale = _flash_dmattn_forward( query, key, value, @@ -1087,14 +1206,20 @@ def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=Fa ) ctx.save_for_backward(query, key, value, o, lse, attn_mask, attn_bias) ctx.is_causal = is_causal + ctx.seqlen_k_bias_og = attn_bias.shape[-1] if attn_bias is not None else 0 return o @staticmethod def backward(ctx, do): query, key, value, o, lse, attn_mask, attn_bias = ctx.saved_tensors - assert not ctx.needs_input_grad[3], "FlashDMAttn does not support mask gradient yet" - dq, dk, dv, dbias = _flash_attn_backward( - do, + + head_size_og = do.size(3) + do_padded = do + if head_size_og % 8 != 0: + do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_og % 8]) + + dq, dk, dv, dbias = _flash_dmattn_backward( + do_padded, query, key, value, @@ -1105,6 +1230,15 @@ def backward(ctx, do): softmax_scale=ctx.softmax_scale, is_causal=ctx.is_causal, ) + + # We could have padded the head dimension + dq = dq[..., : do.shape[-1]] + dk = dk[..., : do.shape[-1]] + dv = dv[..., : do.shape[-1]] + + if dbias is not None: + dbias = dbias[..., :key.shape[1]].sum(dim=-1, keepdim=True) if ctx.seqlen_k_bias_og == 1 else dbias[..., : key.shape[1]] + return dq, dk, dv, None, dbias, None, None
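Note on usage (not part of the patch): the sketch below shows how the updated Triton autograd entry point introduced in this diff can be driven with the new GQA/MQA support, a boolean mask, and an additive bias that share heads with K/V. It assumes direct use of FlashDMAttnFunc.apply with the positional order (query, key, value, attn_mask, attn_bias, is_causal, softmax_scale) inferred from the forward signature, a CUDA device, and shapes that avoid the padding branches; it is an illustrative sketch, not the packaged API.

# Minimal usage sketch; import path and positional argument order are
# assumptions inferred from this diff, not a documented interface.
import math
import torch
from flash_dmattn.flash_dmattn_triton import FlashDMAttnFunc

batch, seqlen_q, seqlen_k = 1, 1024, 1024
nheads, nheads_k, head_dim = 8, 2, 64  # GQA: 8 query heads share 2 KV heads

device, dtype = torch.device("cuda"), torch.bfloat16
q = torch.randn(batch, seqlen_q, nheads, head_dim, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch, seqlen_k, nheads_k, head_dim, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch, seqlen_k, nheads_k, head_dim, device=device, dtype=dtype, requires_grad=True)

# Boolean mask and additive bias with one head per KV head; the launch code maps
# singleton or KV-sized head/batch dimensions onto query heads via the stride arguments.
attn_mask = torch.ones(batch, nheads_k, seqlen_q, seqlen_k, device=device, dtype=torch.bool)
attn_bias = torch.randn(batch, nheads_k, seqlen_q, seqlen_k, device=device, dtype=dtype, requires_grad=True)

out = FlashDMAttnFunc.apply(q, k, v, attn_mask, attn_bias, True, 1.0 / math.sqrt(head_dim))
out.sum().backward()  # dq, dk, dv, and dbias are produced by the Triton backward kernels

With key_len a multiple of 128 and head_dim a multiple of 8, none of the padding branches in FlashDMAttnFunc.forward are exercised, so gradient shapes match the inputs directly; dbias comes back reduced to the bias head count (nheads_k here), matching the reduction added in _flash_dmattn_backward.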