Commit 9d03015

something is wrong with lse with forward gqa, build some stuff for debugging
1 parent d5a76e1 commit 9d03015
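For context, the "lse" in the commit message is the per-query log-sum-exp of the attention logits, i.e. the softmax normalizer in log space; fused attention kernels keep it around so the backward pass (and, here, the test) can recompute or check the softmax without re-materializing it. A minimal, repo-independent sketch of the relationship (tensor shapes are illustrative only):

import torch

# attention logits for (batch, heads, queries, keys)
sim = torch.randn(1, 2, 512, 512)

# log-sum-exp over the key dimension: the softmax normalizer in log space
lse = sim.logsumexp(dim = -1)

# the softmax can be recovered from the logits and the lse alone
attn = (sim - lse.unsqueeze(-1)).exp()
assert torch.allclose(attn, sim.softmax(dim = -1), atol = 1e-5)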

File tree

2 files changed (+38, -8 lines)


native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 23 additions & 4 deletions
@@ -1153,7 +1153,7 @@ def forward(
             block_size = block_size
         )
 
-        fk, fv, selected_block_indices, fmask = tuple(repeat(t, 'b h ... -> b (h g) ...', g = head_groups) for t in (fk, fv, selected_block_indices, fmask))
+        fk, fv, selected_block_indices, fmask = tuple(repeat(t, 'b h ... -> b (h g) ...', g = head_groups).contiguous() for t in (fk, fv, selected_block_indices, fmask))
 
         ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, lse)
 
@@ -1162,10 +1162,10 @@ def forward(
             head_groups
         )
 
-        return out.type(dtype)
+        return out.type(dtype), lse
 
     @classmethod
-    def backward(self, ctx, do):
+    def backward(self, ctx, do, _):
         device = do.device
 
         q, k, v, sel_block_indices, mask, out, lse = ctx.saved_tensors
@@ -1191,4 +1191,23 @@ def backward(self, ctx, do):
 
         return dq, dk, dv, None, None, None, None
 
-native_sparse_attend = NSA.apply
+_native_sparse_attend = NSA.apply
+
+def native_sparse_attend(
+    fq, fk, fv,
+    block_size,
+    selected_block_indices,
+    fmask,
+    return_lse = False
+):
+    out, lse = _native_sparse_attend(
+        fq, fk, fv,
+        block_size,
+        selected_block_indices,
+        fmask,
+    )
+
+    if not return_lse:
+        return out
+
+    return out, lse
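Why backward now takes an extra argument: once forward returns (out, lse) rather than out alone, autograd supplies one incoming gradient per forward output, so backward grows a second gradient slot, which this commit simply discards (the _ parameter). A simplified sketch of the same pattern, not the actual NSA class: plain dense attention, no 1/sqrt(d) scaling, and staticmethods instead of the classmethods used above.

import torch
from torch.autograd import Function

class AttendWithLSE(Function):
    @staticmethod
    def forward(ctx, q, k, v):
        sim = q @ k.transpose(-1, -2)
        lse = sim.logsumexp(dim = -1)    # auxiliary second output
        attn = sim.softmax(dim = -1)
        out = attn @ v
        ctx.save_for_backward(q, k, v, attn)
        return out, lse

    @staticmethod
    def backward(ctx, dout, _):          # second grad slot (for lse) is ignored, as in the commit
        q, k, v, attn = ctx.saved_tensors
        dv = attn.transpose(-1, -2) @ dout
        dattn = dout @ v.transpose(-1, -2)
        # softmax backward
        dsim = attn * (dattn - (dattn * attn).sum(dim = -1, keepdim = True))
        dq = dsim @ k
        dk = dsim.transpose(-1, -2) @ q
        return dq, dk, dv

Wrapping NSA.apply in a small function with a return_lse flag, as the diff does with _native_sparse_attend, keeps the old single-output call signature working for existing callers while exposing the lse for debugging.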

test_triton_nsa.py

Lines changed: 15 additions & 4 deletions
@@ -9,6 +9,9 @@
 def exists(v):
     return v is not None
 
+def abs_diff(x, y):
+    return (x - y).abs().amax()
+
 def divisible_by(num, den):
     return (num % den) == 0
 
@@ -17,6 +20,7 @@ def regular_attend(
     indices,
     mask,
     block_size,
+    return_lse = False
 ):
     q_heads, seq_len, kv_heads, device = q.shape[1], q.shape[-2], k.shape[1], q.device
     assert divisible_by(q_heads, kv_heads)
@@ -77,13 +81,19 @@ def regular_attend(
     else:
         out = einsum(attn, v, 'b h g w i j, b h w j d -> b h g w i d')
 
-    return rearrange(out, 'b h g w n d -> b (h g) (w n) d')
+    out = rearrange(out, 'b h g w n d -> b (h g) (w n) d')
+
+    if not return_lse:
+        return out
+
+    lse = sim.logsumexp(dim = -1)
+    return out, rearrange(lse, 'b g h w n -> b (g h) (w n)')
 
 # mock inputs
 
 fine_block_size = 16
 
-q = torch.randn(1, 4, 512, 64).cuda()
+q = torch.randn(1, 2, 512, 64).cuda()
 k = torch.randn(1, 2, 512, 64).cuda()
 v = torch.randn(1, 2, 512, 64).cuda()
 
@@ -97,17 +107,18 @@ def regular_attend(
 
 # regular forwards and backwards
 
-out = regular_attend(rq, rk, rv, indices, mask, block_size = fine_block_size)
+out, rlse = regular_attend(rq, rk, rv, indices, mask, block_size = fine_block_size, return_lse = True)
 out.sum().backward()
 
 # triton nsa forwards and backwards
 
-nsa_out = native_sparse_attend(nq, nk, nv, fine_block_size, indices, mask)
+nsa_out, nlse = native_sparse_attend(nq, nk, nv, fine_block_size, indices, mask, return_lse = True)
 nsa_out.sum().backward()
 
 # asserts
 
 assert torch.allclose(out, nsa_out, atol = 1e-2)
+assert torch.allclose(rlse, nlse, atol = 1e-2)
 
 assert torch.allclose(nv.grad, rv.grad, atol = 1e-2)
 assert torch.allclose(nk.grad, rk.grad, atol = 1e-2)
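The new abs_diff helper is defined in the test but not yet called in this commit; presumably it is scaffolding for chasing the GQA lse mismatch named in the commit message. A hypothetical debugging snippet using it, appended after the asserts above (the per-head loop and prints are illustrative, not part of the commit; rlse / nlse have shape (batch, heads, seq) after the rearrange in regular_attend):

print('out max abs diff:', abs_diff(out, nsa_out).item())
print('lse max abs diff:', abs_diff(rlse, nlse).item())

# per-head breakdown, to see whether only some GQA head groups disagree
for h in range(rlse.shape[1]):
    print(f'head {h}: lse max abs diff = {abs_diff(rlse[:, h], nlse[:, h]).item():.5f}')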
