Skip to content

Commit bd38fe6

Browse files
[NFC] Fix code factors on inference triton kernels (#5743)
1 parent c2c8c9c commit bd38fe6

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

colossalai/kernel/triton/flash_decoding.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,10 @@ def _flash_decoding_fwd_kernel(
111111
m = tl.max(S_ij, 0)
112112
S_ij -= m
113113
p_ij_hat = tl.exp(S_ij)
114-
l = tl.sum(p_ij_hat, 0)
114+
l_i = tl.sum(p_ij_hat, 0)
115115
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
116116
acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
117-
acc = acc / l
117+
acc = acc / l_i
118118

119119
offsets_mid_o = (
120120
cur_token_idx * stride_mid_ot
@@ -126,8 +126,8 @@ def _flash_decoding_fwd_kernel(
126126
offsets_mid_o_lse = (
127127
cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
128128
)
129-
# logsumexp L^(j) = m^(j) + log(l^(j))
130-
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
129+
# logsumexp l_i^(j) = m^(j) + log(l_i^(j))
130+
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l_i))
131131

132132

133133
# Triton 2.1.0
@@ -234,10 +234,10 @@ def _alibi_flash_decoding_fwd_kernel(
234234
m = tl.max(S_ij, 0)
235235
S_ij -= m
236236
p_ij_hat = tl.exp(S_ij)
237-
l = tl.sum(p_ij_hat, 0)
237+
l_i = tl.sum(p_ij_hat, 0)
238238
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
239239
acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
240-
acc = acc / l
240+
acc = acc / l_i
241241

242242
offsets_mid_o = (
243243
cur_token_idx * stride_mid_ot
@@ -249,8 +249,8 @@ def _alibi_flash_decoding_fwd_kernel(
249249
offsets_mid_o_lse = (
250250
cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
251251
)
252-
# logsumexp L^(j) = m^(j) + log(l^(j))
253-
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
252+
# logsumexp l_i^(j) = m^(j) + log(l_i^(j))
253+
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l_i))
254254

255255

256256
# Triton 2.1.0
@@ -290,7 +290,7 @@ def _flash_decoding_fwd_reduce_kernel(
290290
# BLOCK_KV == BLOCK_SIZE for now. We might want to decrease the number of blocks of kv splitted.
291291
kv_split_num = (cur_kv_seq_len + BLOCK_KV - 1) // BLOCK_KV
292292
m_i = float("-inf") # max logic
293-
l = 0.0 # sum exp
293+
l_i = 0.0 # sum exp
294294
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
295295

296296
offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
@@ -304,10 +304,10 @@ def _flash_decoding_fwd_reduce_kernel(
304304
lse -= m_ij
305305
exp_logic = tl.exp(lse)
306306
acc += exp_logic * mid_o_block
307-
l = scale * l + exp_logic
307+
l_i = scale * l_i + exp_logic
308308
m_i = m_ij
309309

310-
acc = acc / l
310+
acc = acc / l_i
311311
offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
312312
tl.store(O + offsets_O, acc.to(O.type.element_ty))
313313
return

0 commit comments

Comments
 (0)