Commit ea4fb34

continue cleaning
1 parent 25a770a commit ea4fb34

1 file changed (+22, -32 lines)

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 22 additions & 32 deletions
@@ -240,9 +240,6 @@ def flash_attn_forward(
     q,
     k,
     v,
-    o = None,
-    m = None,
-    lse = None,
     softmax_scale = None,
     remove_padding = False,
     block_size = 128
@@ -263,16 +260,11 @@ def flash_attn_forward(
 
     seqlen_q_rounded = ceil(seqlen_q / 128) * 128
 
-    if not exists(lse):
-        max_neg_value = -torch.finfo(torch.float32).max
-        lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
+    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
 
-    if not exists(m):
-        max_neg_value = -torch.finfo(torch.float32).max
-        m = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
+    m = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
 
-    if not exists(o):
-        o = torch.empty_like(q)
+    o = torch.empty_like(q)
 
     BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
     num_warps = 4 if d <= 64 else 8
@@ -779,7 +771,6 @@ def flash_attn_backward(
     dq,
     dk,
     dv,
-    delta = None,
     softmax_scale = None,
     block_size = 128
 ):
@@ -804,26 +795,25 @@ def flash_attn_backward(
 
     BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
 
-    if not exists(delta):
-        delta = torch.empty_like(lse)
-        grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK"]), batch * nheads)
-        _bwd_preprocess_do_o_dot[grid](
-            o,
-            do,
-            delta,
-            o.stride(0),
-            o.stride(2),
-            o.stride(1),
-            do.stride(0),
-            do.stride(2),
-            do.stride(1),
-            nheads,
-            seqlen_q,
-            seqlen_q_rounded,
-            d,
-            BLOCK = block_size,
-            BLOCK_HEADDIM=BLOCK_HEADDIM,
-        )
+    delta = torch.empty_like(lse)
+    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK"]), batch * nheads)
+    _bwd_preprocess_do_o_dot[grid](
+        o,
+        do,
+        delta,
+        o.stride(0),
+        o.stride(2),
+        o.stride(1),
+        do.stride(0),
+        do.stride(2),
+        do.stride(1),
+        nheads,
+        seqlen_q,
+        seqlen_q_rounded,
+        d,
+        BLOCK = block_size,
+        BLOCK_HEADDIM=BLOCK_HEADDIM,
+    )
 
     # BLOCK_M = 128
     # BLOCK_N = 64
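
In short, the commit drops the optional pre-allocated o, m, lse arguments from flash_attn_forward (and delta from flash_attn_backward); both paths now allocate these buffers unconditionally. A minimal sketch of the resulting forward-side allocations, written as a standalone helper for illustration only (allocate_forward_buffers is not a function in the repository; shapes and dtypes follow the added lines in the diff):

    import torch
    from math import ceil

    def allocate_forward_buffers(q, batch, nheads, seqlen_q):
        # sequence length rounded up to a multiple of 128, matching the kernel's block layout
        seqlen_q_rounded = ceil(seqlen_q / 128) * 128

        # per-row log-sum-exp and running maxima, kept in float32 regardless of q's dtype
        lse = torch.empty((batch, nheads, seqlen_q_rounded), device = q.device, dtype = torch.float32)
        m = torch.empty((batch, nheads, seqlen_q_rounded), device = q.device, dtype = torch.float32)

        # attention output, same shape and dtype as the queries
        o = torch.empty_like(q)
        return o, m, lse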

0 commit comments
