Commit f533c14

cleanup

1 parent dbcf080

File tree

1 file changed (+2, -13 lines)

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 2 additions & 13 deletions
@@ -65,7 +65,6 @@ def forward_kernel(
     kv_block_indices,
     kv_block_mask,
     Out,
-    M,
     Lse,
     softmax_scale,
     stride_qb,
@@ -118,8 +117,6 @@ def forward_kernel(

     # maximum

-    m_ptrs = M + off_hb * seqlen_q_rounded + offs_m
-
     m_i = tl.zeros([BLOCK], dtype = tl.float32) - float("inf")

     # lse
@@ -189,7 +186,7 @@ def forward_kernel(
     l_ij = tl.sum(p, 1)

     acc_o_scale = tl.exp(m_i - m_ij)
-    acc_o = acc_o * acc_o_scale[:, None]
+    acc_o *= acc_o_scale[:, None]

     if EVEN_N & EVEN_M:
         if EVEN_HEADDIM:
@@ -302,12 +299,7 @@ def forward_kernel(
     # normalize accumulated out

     acc_o_scale = tl.exp(m_i - lse_i)
-    acc_o = acc_o * acc_o_scale[:, None]
-
-    # offsets
-
-    start_m = tl.program_id(0)
-    offs_m = start_m * BLOCK + tl.arange(0, BLOCK)
+    acc_o *= acc_o_scale[:, None]

     # write back lse

@@ -357,8 +349,6 @@ def flash_attn_forward(

     lse = torch.empty((batch, nheads, seqlen_q_rounded), device = device, dtype = torch.float32)

-    m = torch.empty((batch, nheads, seqlen_q_rounded), device = device, dtype = torch.float32)
-
     o = torch.empty_like(q)

     BLOCK_HEADDIM = max(triton.next_power_of_2(dim), 16)
@@ -372,7 +362,6 @@ def flash_attn_forward(
     kv_block_indices,
     kv_block_mask,
     o,
-    m,
     lse,
     softmax_scale,
     q.stride(0),
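
For context on what the cleanup removes: the kernel runs a FlashAttention-style online softmax, rescaling the output accumulator by exp(m_i - m_ij) whenever the running row max updates, and by exp(m_i - lse_i) at the end. Only the log-sum-exp (Lse) is ever written back (the "# write back lse" step survives), so the separate M buffer, its m_ptrs pointer arithmetic, and a recomputation of start_m / offs_m after normalization (already available from earlier in the kernel) can all go, and the two rescales become in-place multiplies. Below is a minimal sketch of that accumulation in plain PyTorch; the names m_i, lse_i, acc_o, l_ij, and acc_o_scale mirror the Triton code, while online_softmax_attention, the block size, and the single-head (seq, dim) shapes are illustrative assumptions rather than the kernel's real interface.

import torch

def online_softmax_attention(q, k, v, block = 4):
    # a minimal sketch of the online-softmax accumulation, not the actual kernel
    scale = q.shape[-1] ** -0.5
    m_i = torch.full((q.shape[0],), float("-inf"))    # running row max
    lse_i = torch.full((q.shape[0],), float("-inf"))  # running log-sum-exp
    acc_o = torch.zeros_like(q)                       # unnormalized output accumulator

    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        qk = (q @ kb.t()) * scale

        m_ij = torch.maximum(m_i, qk.amax(dim = -1))  # updated running max
        p = torch.exp(qk - m_ij[:, None])
        l_ij = p.sum(dim = -1)

        acc_o_scale = torch.exp(m_i - m_ij)           # rescale previous accumulator
        acc_o *= acc_o_scale[:, None]
        acc_o += p @ vb

        lse_i = m_ij + torch.log(torch.exp(lse_i - m_ij) + l_ij)
        m_i = m_ij

    # final normalization - after this point only lse_i needs to be written
    # back, which is why the separate row-max buffer was redundant
    acc_o *= torch.exp(m_i - lse_i)[:, None]
    return acc_o, lse_i

# sanity check against plain softmax attention
q, k, v = torch.randn(8, 16), torch.randn(12, 16), torch.randn(12, 16)
o, _ = online_softmax_attention(q, k, v)
ref = torch.softmax((q @ k.t()) * 16 ** -0.5, dim = -1) @ v
assert torch.allclose(o, ref, atol = 1e-5)

The check at the end confirms the block-wise accumulation matches a plain softmax reference, which is exactly the invariant that makes the running max a purely internal quantity.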
