diff --git a/_collections/_portal_posts/2025-09-02-improving-triton-flashattention-performance-on-intel-gpu.md b/_collections/_portal_posts/2025-09-02-improving-triton-flashattention-performance-on-intel-gpu.md index 882f1f6..33df973 100644 --- a/_collections/_portal_posts/2025-09-02-improving-triton-flashattention-performance-on-intel-gpu.md +++ b/_collections/_portal_posts/2025-09-02-improving-triton-flashattention-performance-on-intel-gpu.md @@ -207,45 +207,45 @@ This is exactly what happens in the widespread case of [FlashAttention version 2 The FlashAttention v2 Forward pass algorithm in pseudo-code is: ```python -# Inputs : Q, K and V are 2D Matrices in Global Memory -def FlashAttention2_forward(Q, K, V): - O = torch.zeros_like(Q, requires_grad=True) - L = torch.zeros(Q.shape[:-1])[...,None] - - Q_BLOCKS = torch.split(Q, BLOCK_SHAPE) - K_BLOCKS = torch.split(K, BLOCK_SHAPE) - V_BLOCKS = torch.split(V, BLOCK_SHAPE) - - Tr = len(Q_BLOCKS) - Tc = len(K_BLOCKS) - - for i in range(Tr): - Qi = load(Q_BLOCKS[i]) # Load data from Global Memory to SRAM - Oi = torch.zeros(BLOCK_SHAPE) # No load required, Initialized on chip - li = torch.zeros(BLOCK_SHAPE) # No load required, Initialized on chip - mi = NEG_INF # No load required, Initialized on chip - - for j in range(Tc): - Kj = load(K_BLOCKS[j]) # Load data from Global Memory to SRAM - Vj = load(V_BLOCKS[j]) # Load data from Global Memory to SRAM - - KTj = Kj.transpose() - S_ij = matmul(Qi, KTj) - - P_ij, m_block_ij, mi_new, li_new = online_softmax(S_ij, mi, li) - - P_ij_Vj = matmul(P_ij, Vj) - Oij = (li/li_new) * torch.exp(mi - mi_new) * Oi + (torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj - - # update li and mi - li = li_new - mi = mi_new - - Oi = Oij / diag(li) - O.store(Oi, i) # Store data to Global Memory as the i-th block of O - L.store(li, i) # Store data to Global Memory as the i-th block of L - - return O, L +1 # Inputs : Q, K and V are 2D Matrices in Global Memory +2 def FlashAttention2_forward(Q, K, V): +3 O = torch.zeros_like(Q, requires_grad=True) +4 L = torch.zeros(Q.shape[:-1])[...,None] +5 +6 Q_BLOCKS = torch.split(Q, BLOCK_SHAPE) +7 K_BLOCKS = torch.split(K, BLOCK_SHAPE) +8 V_BLOCKS = torch.split(V, BLOCK_SHAPE) +9 +10 Tr = len(Q_BLOCKS) +11 Tc = len(K_BLOCKS) +12 +13 for i in range(Tr): +14 Qi = load(Q_BLOCKS[i]) # Load data from Global Memory to SRAM +15 Oi = torch.zeros(BLOCK_SHAPE) # No load required, Initialized on chip +16 li = torch.zeros(BLOCK_SHAPE) # No load required, Initialized on chip +17 mi = NEG_INF # No load required, Initialized on chip +18 +19 for j in range(Tc): +20 Kj = load(K_BLOCKS[j]) # Load data from Global Memory to SRAM +21 Vj = load(V_BLOCKS[j]) # Load data from Global Memory to SRAM +22 +23 KTj = Kj.transpose() +24 S_ij = matmul(Qi, KTj) +25 +26 P_ij, m_block_ij, mi_new, li_new = online_softmax(S_ij, mi, li) +27 +28 P_ij_Vj = matmul(P_ij, Vj) +29 Oij = (li/li_new) * torch.exp(mi - mi_new) * Oi + (torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj +30 +31 # update li and mi +32 li = li_new +33 mi = mi_new +34 +35 Oi = Oij / diag(li) +36 O.store(Oi, i) # Store data to Global Memory as the i-th block of O +37 L.store(li, i) # Store data to Global Memory as the i-th block of L +38 +39 return O, L ``` In the second version of the implementation of the FlashAttention model, the loop order has been reversed to promote @@ -253,9 +253,9 @@ data locality. As long as there is enough local memory (or registers) to contain all the needed data, this algorithm works fine and provides significant performance improvements compared to FlashAttention v1 (in the paper, the authors mention 2x faster for the Cutlass implementation and 1.3-1.5× faster in Triton on an Nvidia Ampere GPU A100). -Deployed on a GPU target, line 4-10 constitutes the computing kernel that is dispatched to a Thread Block/Work-Group ( +Deployed on a GPU target, line 13-37 constitutes the computing kernel that is dispatched to a Thread Block/Work-Group ( i.e. a SM/XeCore). -But as you can see, variable Q is loaded before the loop (line 4) and remains *live* across the loop. +As you can see, variable Q is loaded before the loop (line 14) and remains *live* across the loop. The long lifespan of variable Q is even more problematic in the causal variation of the FlashAttention implementation. The causal variation is defined in the paper as : @@ -264,7 +264,7 @@ The causal variation is defined in the paper as : The Triton implementation of FlashAttention v2 with causal mask is as follow: -```python {.line-numbers} +```python @triton.jit def _attn_fwd(Q_block_ptr, K_block_ptr, V_block_ptr, sm_scale, M, N_CTX: tl.constexpr, # BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #