@@ -111,10 +111,10 @@ def _flash_decoding_fwd_kernel(
111
111
m = tl .max (S_ij , 0 )
112
112
S_ij -= m
113
113
p_ij_hat = tl .exp (S_ij )
114
- l = tl .sum (p_ij_hat , 0 )
114
+ l_i = tl .sum (p_ij_hat , 0 )
115
115
p_ij_hat = p_ij_hat .to (v_cur_block .type .element_ty )
116
116
acc += tl .sum (v_cur_block * p_ij_hat [:, None ], 0 )
117
- acc = acc / l
117
+ acc = acc / l_i
118
118
119
119
offsets_mid_o = (
120
120
cur_token_idx * stride_mid_ot
@@ -126,8 +126,8 @@ def _flash_decoding_fwd_kernel(
126
126
offsets_mid_o_lse = (
127
127
cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
128
128
)
129
- # logsumexp L ^(j) = m^(j) + log(l ^(j))
130
- tl .store (mid_o_lse + offsets_mid_o_lse , m + tl .log (l ))
129
+ # logsumexp L^(j) = m^(j) + log(l_i^(j))
130
+ tl .store (mid_o_lse + offsets_mid_o_lse , m + tl .log (l_i ))
131
131
132
132
133
133
# Triton 2.1.0
@@ -234,10 +234,10 @@ def _alibi_flash_decoding_fwd_kernel(
234
234
m = tl .max (S_ij , 0 )
235
235
S_ij -= m
236
236
p_ij_hat = tl .exp (S_ij )
237
- l = tl .sum (p_ij_hat , 0 )
237
+ l_i = tl .sum (p_ij_hat , 0 )
238
238
p_ij_hat = p_ij_hat .to (v_cur_block .type .element_ty )
239
239
acc += tl .sum (v_cur_block * p_ij_hat [:, None ], 0 )
240
- acc = acc / l
240
+ acc = acc / l_i
241
241
242
242
offsets_mid_o = (
243
243
cur_token_idx * stride_mid_ot
@@ -249,8 +249,8 @@ def _alibi_flash_decoding_fwd_kernel(
249
249
offsets_mid_o_lse = (
250
250
cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
251
251
)
252
- # logsumexp L ^(j) = m^(j) + log(l ^(j))
253
- tl .store (mid_o_lse + offsets_mid_o_lse , m + tl .log (l ))
252
+ # logsumexp L^(j) = m^(j) + log(l_i^(j))
253
+ tl .store (mid_o_lse + offsets_mid_o_lse , m + tl .log (l_i ))
254
254
255
255
256
256
# Triton 2.1.0
@@ -290,7 +290,7 @@ def _flash_decoding_fwd_reduce_kernel(
290
290
# BLOCK_KV == BLOCK_SIZE for now. We might want to decrease the number of blocks the kv is split into.
291
291
kv_split_num = (cur_kv_seq_len + BLOCK_KV - 1 ) // BLOCK_KV
292
292
m_i = float ("-inf" ) # max logic
293
- l = 0.0 # sum exp
293
+ l_i = 0.0 # sum exp
294
294
acc = tl .zeros ([HEAD_DIM ], dtype = tl .float32 )
295
295
296
296
offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
@@ -304,10 +304,10 @@ def _flash_decoding_fwd_reduce_kernel(
304
304
lse -= m_ij
305
305
exp_logic = tl .exp (lse )
306
306
acc += exp_logic * mid_o_block
307
- l = scale * l + exp_logic
307
+ l_i = scale * l_i + exp_logic
308
308
m_i = m_ij
309
309
310
- acc = acc / l
310
+ acc = acc / l_i
311
311
offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
312
312
tl .store (O + offsets_O , acc .to (O .type .element_ty ))
313
313
return
0 commit comments