
Commit 04603d9

gqa works for main block causal, but broken for fine selection pathway, hack on it another day
1 parent aa37f90 commit 04603d9

4 files changed: +54, -25 lines
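
The commit message refers to grouped query attention (GQA): the query projection carries `g` times as many heads as the key / value projection, so each key / value head is shared by a group of `g` query heads, and the gradients flowing into the shared heads have to be summed back over the group. A minimal PyTorch sketch of that layout, with illustrative sizes and names (the `(h g)` head ordering matches the `reduce(..., 'b (h g) ... -> b h ...', 'sum')` call removed in the diff below):

import torch
from einops import repeat

b, kv_heads, groups, n, d = 2, 2, 2, 128, 64   # illustrative sizes, not the repo defaults

q = torch.randn(b, kv_heads * groups, n, d, requires_grad = True)
k = torch.randn(b, kv_heads, n, d, requires_grad = True)
v = torch.randn(b, kv_heads, n, d, requires_grad = True)

# forward: each key / value head is repeated across its group of query heads
k_rep = repeat(k, 'b h n d -> b (h g) n d', g = groups)
v_rep = repeat(v, 'b h n d -> b (h g) n d', g = groups)

attn = (q @ k_rep.transpose(-2, -1) * d ** -0.5).softmax(dim = -1)
out = attn @ v_rep

# backward: gradients from the repeated copies are summed back to the kv heads.
# autograd does this automatically here; a hand-written kernel must do the
# group-wise sum itself (either post hoc or inside the kernel, as in this commit)
out.sum().backward()
assert k.grad.shape == (b, kv_heads, n, d)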

native_sparse_attention_pytorch/transformer.py

Lines changed: 1 addition & 1 deletion
@@ -183,7 +183,7 @@ def forward(
         ids,
         return_loss = False,
         disable_flex = False,
-        disable_triton_kernel = True
+        disable_triton_kernel = False
     ):
         if return_loss:
             ids, labels = ids[:, :-1], ids[:, 1:]

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 42 additions & 14 deletions
@@ -802,32 +802,50 @@ def backward_kernel_one_col_block(
     block_k = tl.load(block_k_ptrs)
     block_v = tl.load(block_v_ptrs)

-    q_expanded = tl.expand_dims(q, 1)
-    q_expanded = tl.broadcast_to(q_expanded, (BLOCK, 16, BLOCK_HEADDIM))
+    q_expanded = q.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM)
+    q_expanded = q_expanded.permute(1, 0, 2)
+    q_expanded = tl.expand_dims(q_expanded, 2)
+    q_expanded = tl.broadcast_to(q_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
+    q_expanded = q_expanded.reshape(BLOCK, 16, BLOCK_HEADDIM)

     block_k_permuted = tl.permute(block_k, (0, 2, 1))
     block_qk = tl.dot(q_expanded, block_k_permuted)

-    qk = tl.sum(block_qk, 1) / 16.
-    qk += tl.where(block_masks[:, None], 0, float("-inf"))
+    block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
+    qk = tl.sum(block_qk, 2) / QUERY_EXPAND_DIM
+    qk = qk.permute(1, 0, 2)
+
+    qk += tl.where(block_masks[None, :, None], 0, float("-inf"))
+
+    qk = qk.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK)

     p = tl.exp(qk * softmax_scale - lse_i[:, None])

     # take care of block dv

     block_dv = p.to(do.dtype)[:, :, None] * do[:, None, :]
-    block_dv = tl.where(block_masks[:, None, None], block_dv, 0.)
+
+    block_dv = block_dv.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK, BLOCK_HEADDIM)
+    block_dv = tl.sum(block_dv, 0)

     tl.atomic_add(block_dv_ptrs, block_dv, sem = 'relaxed')

     # get dp

-    do_expanded = tl.expand_dims(do, 1)
-    do_expanded = tl.broadcast_to(do_expanded, (BLOCK, 16, BLOCK_HEADDIM))
+    do_expanded = do.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM)
+    do_expanded = do_expanded.permute(1, 0, 2)
+    do_expanded = tl.expand_dims(do_expanded, 2)
+    do_expanded = tl.broadcast_to(do_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
+    do_expanded = do_expanded.reshape(BLOCK, 16, BLOCK_HEADDIM)
+
     block_v = tl.permute(block_v, (0, 2, 1))

     dp = tl.dot(do_expanded, block_v)
-    dp = tl.sum(dp, 1) / 16.
+
+    dp = dp.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
+    dp = tl.sum(dp, 2) / QUERY_EXPAND_DIM
+    dp = dp.permute(1, 0, 2)
+    dp = dp.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK)

     # ds
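
The expand / broadcast / reshape sequence above works around `tl.dot` wanting each operand dimension to be at least 16: only `QUERY_HEAD_GROUPS` grouped query rows exist per position, so they are duplicated along a dummy `QUERY_EXPAND_DIM` axis (with `QUERY_HEAD_GROUPS * QUERY_EXPAND_DIM = 16`), the dot product runs at the padded size, and the duplicate results are averaged back out. A small PyTorch check of that equivalence, with illustrative shapes (this mirrors the kernel's arithmetic, not its memory layout):

import torch

BLOCK, G, D = 16, 2, 64        # queries per block, query head groups, head dim
E = 16 // G                    # QUERY_EXPAND_DIM: pad G up to tl.dot's minimum of 16

q = torch.randn(G, BLOCK, D)   # grouped queries attending to one kv head
k = torch.randn(BLOCK, D)      # one fine block of keys

# direct computation of the (G * BLOCK, BLOCK) score matrix
qk_direct = q.reshape(G * BLOCK, D) @ k.t()

# kernel-style computation: broadcast each query E times, dot at the padded
# size of 16, then average the E identical copies back out
q_exp = q.permute(1, 0, 2)[:, :, None, :].expand(BLOCK, G, E, D).reshape(BLOCK, 16, D)
qk = q_exp @ k.t()                                  # (BLOCK, 16, BLOCK)
qk = qk.reshape(BLOCK, G, E, BLOCK).sum(2) / E      # remove the dummy axis
qk = qk.permute(1, 0, 2).reshape(G * BLOCK, BLOCK)

assert torch.allclose(qk_direct, qk, atol = 1e-5)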

@@ -837,15 +855,25 @@ def backward_kernel_one_col_block(
     # block dk

     block_dk = ds[:, :, None] * q[:, None, :]
+    block_dk = block_dk.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK, BLOCK_HEADDIM)
+    block_dk = tl.sum(block_dk, 0)

     tl.atomic_add(block_dk_ptrs, block_dk, sem = 'relaxed')

     # block dq

-    ds_expanded = tl.expand_dims(ds, 1)
-    ds_expanded = tl.broadcast_to(ds_expanded, (BLOCK, 16, BLOCK))
+    ds_expanded = ds.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK)
+    ds_expanded = ds_expanded.permute(1, 0, 2)
+    ds_expanded = tl.expand_dims(ds_expanded, 2)
+    ds_expanded = tl.broadcast_to(ds_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
+    ds_expanded = ds_expanded.reshape(BLOCK, 16, BLOCK)
+
     block_dq = tl.dot(ds_expanded, block_k)
-    block_dq = tl.sum(block_dq, 1) / 16
+
+    block_dq = block_dq.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM)
+    block_dq = tl.sum(block_dq, 2) / QUERY_EXPAND_DIM
+    block_dq = block_dq.permute(1, 0, 2)
+    block_dq = block_dq.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM)

     dq += block_dq
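
With the query-group axis folded into the row dimension, `p` and `ds` have `QUERY_HEAD_GROUPS * BLOCK` rows, but the key / value block being updated belongs to a single kv head, so the outer products are reshaped back to `(QUERY_HEAD_GROUPS, BLOCK, BLOCK, BLOCK_HEADDIM)` and summed over the group axis before the atomic add. The same reduction written in plain PyTorch, on illustrative shapes:

import torch

G, BLOCK, D = 2, 16, 64

p  = torch.randn(G * BLOCK, BLOCK)   # post-softmax scores, group axis folded into the rows
do = torch.randn(G * BLOCK, D)       # upstream gradient of the attention output rows

# per-row outer products, then a sum over the query head group axis, keeping
# the per-query axis for the scatter into the selected kv block
block_dv = p[:, :, None] * do[:, None, :]               # (G * BLOCK, BLOCK, D)
block_dv = block_dv.reshape(G, BLOCK, BLOCK, D).sum(0)  # (BLOCK, BLOCK, D)

# the same thing as a single einsum over the grouped layout
ref = torch.einsum('g i j, g i d -> i j d', p.reshape(G, BLOCK, BLOCK), do.reshape(G, BLOCK, D))
assert torch.allclose(block_dv, ref, atol = 1e-5)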

@@ -1194,9 +1222,7 @@ def backward(self, ctx, do, _):
         out, lse, dq, dk, dv,
         block_size = block_size
     )
-
-    dk, dv = tuple(reduce(t, 'b (h g) ... -> b h ...', 'sum', g = head_groups) for t in (dk, dv))
-
+
     return dq, dk, dv, None, None, None, None

 _native_sparse_attend = NSA.apply
@@ -1208,6 +1234,8 @@ def native_sparse_attend(
     fmask,
     return_lse = False
 ):
+    assert divisible_by(fq.shape[-2], block_size)
+
     out, lse = _native_sparse_attend(
         fq, fk, fv,
         block_size,

test_triton_nsa.py

Lines changed: 5 additions & 5 deletions
@@ -93,12 +93,12 @@ def regular_attend(

 fine_block_size = 16

-q = torch.randn(1, 2, 512, 64).cuda()
-k = torch.randn(1, 2, 512, 64).cuda()
-v = torch.randn(1, 2, 512, 64).cuda()
+q = torch.randn(2, 4, 512, 64).cuda()
+k = torch.randn(2, 2, 512, 64).cuda()
+v = torch.randn(2, 2, 512, 64).cuda()

-indices = torch.zeros(1, 2, 512, 1).long().cuda()
-mask = torch.ones(1, 2, 512, 1).bool().cuda()
+indices = torch.zeros(2, 2, 512, 0).long().cuda()
+mask = torch.randint(0, 2, (2, 2, 512, 0)).bool().cuda()

 # both regular and nsa pathways `r` and `n`
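
In the updated test, `q` now has 4 heads while `k` / `v` have 2 (two query heads per kv head), and `indices` / `mask` have a trailing dimension of 0, i.e. no fine blocks are selected, matching `NUM_FINE_SELECTED = 0` in the training script. As I read the shapes, a non-empty selection would look roughly like the sketch below (values are illustrative, not taken from the test):

import torch

batch, kv_heads, seq, block_size = 2, 2, 512, 16
num_sel = 2                        # fine blocks selected per query position
num_blocks = seq // block_size     # 32 selectable fine blocks

# per (batch, kv head, query position): which fine blocks to attend to,
# plus a boolean mask marking which selected slots are actually valid
indices = torch.randint(0, num_blocks, (batch, kv_heads, seq, num_sel)).cuda()
mask = torch.ones(batch, kv_heads, seq, num_sel, dtype = torch.bool).cuda()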

train_triton_nsa.py

Lines changed: 6 additions & 5 deletions
@@ -25,24 +25,25 @@
 LEARNING_RATE = 1e-4
 VALIDATE_EVERY = 100
 PRIME_LENGTH = 64
+SHOULD_GENERATE = False
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 512
 SEQ_LEN = 512
 HEADS = 8
-KV_HEADS = 8
+KV_HEADS = 4

 USE_SPARSE_ATTN = True
 USE_TRITON_NSA = True
-USE_FLEX_FOR_FINE_SELECTION = False # will push flex a bit, won't be efficient as each layer needs sparsity dynmically generated, but may be enough just to compare to full attention before going all-in on triton kernels
-QUERY_HEADS_SHARE_SELECTION = False # if set to False, each query head can look at a different segment of their corresponding key / value head in GQA
+USE_FLEX_FOR_FINE_SELECTION = False # will push flex a bit, won't be efficient as each layer needs sparsity dynmically generated, but may be enough just to compare to full attention before going all-in on triton kernels
+QUERY_HEADS_SHARE_SELECTION = True # if set to False, each query head can look at a different segment of their corresponding key / value head in GQA

 # sparse attention related

 SLIDING_WINDOW_SIZE = 32
 COMPRESS_BLOCK_SIZE = 16

 FINE_BLOCK_SIZE = 16
-NUM_FINE_SELECTED = 1
+NUM_FINE_SELECTED = 0

 INTERPOLATED_IMPORTANCE_SCORE = False
 USE_DIFF_TOPK = True
@@ -211,7 +212,7 @@ def __getitem__(self, index):
         wandb.log(dict(valid_loss = loss.item()), step = i)
         print(f"validation loss: {loss.item():.3f}")

-    if i % GENERATE_EVERY == 0:
+    if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
         model.eval()

         inp = random.choice(val_dataset)[:PRIME_LENGTH]
