
Commit 1a6019d (parent f6515e0)

found issue with intermittent backwards error, get e2e train script for triton nsa

File tree: 3 files changed, +230 / -4 lines

  native_sparse_attention_pytorch/triton_native_sparse_attention.py
  pyproject.toml
  train_triton_nsa.py

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 3 additions & 3 deletions

@@ -975,9 +975,9 @@ def flash_attn_backward(

     softmax_scale = dim ** -0.5

-    dq_accum = torch.empty_like(q, dtype = torch.float32)
-    dk_accum = torch.empty_like(k, dtype = torch.float32)
-    dv_accum = torch.empty_like(v, dtype = torch.float32)
+    dq_accum = torch.zeros_like(q, dtype = torch.float32)
+    dk_accum = torch.zeros_like(k, dtype = torch.float32)
+    dv_accum = torch.zeros_like(v, dtype = torch.float32)

     # delta = torch.zeros_like(lse)
pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.48"
+version = "0.0.49"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

train_triton_nsa.py (new file)

Lines changed: 226 additions & 0 deletions

@@ -0,0 +1,226 @@

import math
import gzip
import random
from tqdm import tqdm
import numpy as np

import torch
from torch.optim import Adam
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from native_sparse_attention_pytorch.transformer import Transformer

from native_sparse_attention_pytorch.compress_networks import (
    ConvLinearCompress,
    AttentionPool,
    GroupedMLP
)

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRAD_ACCUM_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 64
SHOULD_GENERATE = False
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512
HEADS = 8
KV_HEADS = 8

USE_SPARSE_ATTN = True
USE_TRITON_NSA = True
USE_FLEX_FOR_FINE_SELECTION = False  # will push flex a bit, won't be efficient as each layer needs sparsity dynamically generated, but may be enough just to compare to full attention before going all-in on triton kernels
QUERY_HEADS_SHARE_SELECTION = False  # if set to False, each query head can look at a different segment of their corresponding key / value head in GQA

# sparse attention related

SLIDING_WINDOW_SIZE = 32
COMPRESS_BLOCK_SIZE = 16

FINE_BLOCK_SIZE = 16
NUM_FINE_SELECTED = 1

INTERPOLATED_IMPORTANCE_SCORE = False
USE_DIFF_TOPK = True

# experiment related

PROJECT_NAME = 'native-sparse-attention'
RUN_NAME = 'baseline' if not USE_SPARSE_ATTN else f'sparse-attn: compress size {COMPRESS_BLOCK_SIZE} | fine size {FINE_BLOCK_SIZE} | {NUM_FINE_SELECTED} selected'
WANDB_ONLINE = False  # turn this on to pipe experiment to cloud

# helpers

def exists(v):
    return v is not None

def cycle(loader):
    while True:
        for data in loader:
            yield data

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

# sampling helpers

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim, keepdim = keepdim)

def top_k(logits, thres = 0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(-1, ind, val)
    return probs

def base_decoding(
    net,
    prompt: Tensor,
    seq_len: int,
    temperature = 1.,
    filter_thres = 0.9,
):
    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
    sample_num_times = max(0, seq_len - prompt_seq_len)

    for _ in tqdm(range(sample_num_times)):
        logits = net(out, disable_flex = True)

        logits = logits[:, -1]
        logits = top_k(logits, thres = filter_thres)
        sample = gumbel_sample(logits, temperature = temperature, dim = -1)

        out = torch.cat((out, sample), dim = -1)

    return out[..., prompt_seq_len:]

# model

model = Transformer(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    heads = HEADS,
    dim_head = 64,
    kv_heads = KV_HEADS,
    use_sparse_attn = USE_SPARSE_ATTN,
    use_flex_sliding_window = True,
    use_triton_fine_selection = USE_TRITON_NSA,
    use_flex_fine_selection = USE_FLEX_FOR_FINE_SELECTION,
    sparse_attn_kwargs = dict(
        sliding_window_size = SLIDING_WINDOW_SIZE,
        compress_block_size = COMPRESS_BLOCK_SIZE,
        compress_mlp = GroupedMLP(
            dim_head = 64,
            compress_block_size = COMPRESS_BLOCK_SIZE,
            heads = KV_HEADS,
        ),
        selection_block_size = FINE_BLOCK_SIZE,
        num_selected_blocks = NUM_FINE_SELECTED,
        use_diff_topk = USE_DIFF_TOPK,
        interpolated_importance_score = INTERPOLATED_IMPORTANCE_SCORE,
        query_heads_share_selected_kv = QUERY_HEADS_SHARE_SELECTION
    )
).cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        return self.data.size(0) // self.seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE)
val_loader = DataLoader(val_dataset, batch_size = BATCH_SIZE)

# optimizer

optim = Adam(model.parameters(), lr = LEARNING_RATE)

train_loader = cycle(train_loader)
val_loader = cycle(val_loader)

# wandb experiment tracker

import wandb
wandb.init(project = PROJECT_NAME, mode = 'disabled' if not WANDB_ONLINE else 'online')
wandb.run.name = RUN_NAME
wandb.run.save()

# training

for i in tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    for _ in range(GRAD_ACCUM_EVERY):
        data = next(train_loader)

        loss = model(data, return_loss = True)

        (loss / GRAD_ACCUM_EVERY).backward()

    wandb.log(dict(loss = loss.item()), step = i)
    print(f"training loss: {loss.item():.3f}")

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            valid_data = next(val_loader)

            loss = model(valid_data, return_loss = True)
            wandb.log(dict(valid_loss = loss.item()), step = i)
            print(f"validation loss: {loss.item():.3f}")

    if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
        model.eval()

        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        inp = inp.cuda()

        prime = decode_tokens(inp)
        print(f"\n{prime}\n")

        prompt = inp[None, ...]

        sampled = base_decoding(model, prompt, GENERATE_LENGTH)

        base_decode_output = decode_tokens(sampled[0])

        print(f"\n{base_decode_output}\n")
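A note on running the script: it expects a CUDA device and the character-level enwik8 dump at ./data/enwik8.gz, the same layout as lucidrains' other training scripts. If the archive is not already present, a small helper like the sketch below could prepare it; the download URL is the usual enwik8 mirror from the Large Text Compression Benchmark and is an assumption, not something taken from this commit.

# hypothetical data-prep helper, not part of this commit: fetches enwik8 and
# writes the gzipped copy that train_triton_nsa.py expects at ./data/enwik8.gz
import gzip
import io
import os
import urllib.request
import zipfile

os.makedirs('./data', exist_ok = True)

if not os.path.exists('./data/enwik8.gz'):
    # assumed mirror for enwik8 (Large Text Compression Benchmark)
    raw = urllib.request.urlopen('http://mattmahoney.net/dc/enwik8.zip').read()
    enwik8 = zipfile.ZipFile(io.BytesIO(raw)).read('enwik8')

    with gzip.open('./data/enwik8.gz', 'wb') as f:
        f.write(enwik8)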