
Commit 2eb7cf6

able to generate samples in triton train script
1 parent 4d34b8a commit 2eb7cf6

3 files changed: 13 additions & 6 deletions

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 2 additions & 1 deletion
@@ -318,6 +318,7 @@ def __init__(
     def forward(
         self,
         inp,
+        disable_triton_kernel = False,
         sliding_window_flex_mask = None,
         fine_selection_flex_mask = None
     ):
@@ -441,7 +442,7 @@ def forward(
         gates = gates.cumprod(dim = -1)[..., -1]
         gates = repeat(gates, 'b h ... -> b (h qh) ...', qh = fine_num_grouped_queries)

-        if self.use_triton_kernel:
+        if self.use_triton_kernel and not disable_triton_kernel:
             from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend

             fmask = selected_importance_values > 1e-10
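
For reference, a minimal usage sketch of the new argument (not part of this commit): the import path and constructor arguments below follow the repository README and are assumptions here; only the disable_triton_kernel keyword on forward comes from the diff above.

import torch
from native_sparse_attention_pytorch import SparseAttention

# constructor arguments are assumptions taken from the repository README
attn = SparseAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    sliding_window_size = 2,
    compress_block_size = 4,
    selection_block_size = 4,
    num_selected_blocks = 2,
    use_triton_kernel = True
)

tokens = torch.randn(1, 512, 512)

# even with use_triton_kernel = True on the module, this call takes the
# plain pytorch path because of the new per-call flag
out = attn(tokens, disable_triton_kernel = True)

assert out.shape == tokens.shape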

native_sparse_attention_pytorch/transformer.py

Lines changed: 5 additions & 2 deletions
@@ -182,7 +182,8 @@ def forward(
         self,
         ids,
         return_loss = False,
-        disable_flex = False
+        disable_flex = False,
+        disable_triton_kernel = True
     ):
         if return_loss:
             ids, labels = ids[:, :-1], ids[:, 1:]
@@ -195,7 +196,9 @@ def forward(

         # prepare maybe flex attention masks

-        attn_kwargs = dict()
+        attn_kwargs = dict(
+            disable_triton_kernel = disable_triton_kernel
+        )

         if not disable_flex and self.use_flex_sliding_window:
             attn_kwargs.update(

train_triton_nsa.py

Lines changed: 6 additions & 3 deletions
@@ -25,7 +25,6 @@
 LEARNING_RATE = 1e-4
 VALIDATE_EVERY = 100
 PRIME_LENGTH = 64
-SHOULD_GENERATE = False
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 512
 SEQ_LEN = 512
@@ -100,7 +99,11 @@ def base_decoding(
     sample_num_times = max(0, seq_len - prompt_seq_len)

     for _ in tqdm(range(sample_num_times)):
-        logits = net(out, disable_flex = True)
+        logits = net(
+            out,
+            disable_flex = True,
+            disable_triton_kernel = True
+        )

         logits = logits[:, -1]
         logits = top_k(logits, thres = filter_thres)
@@ -208,7 +211,7 @@ def __getitem__(self, index):
             wandb.log(dict(valid_loss = loss.item()), step = i)
             print(f"validation loss: {loss.item():.3f}")

-    if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
+    if i % GENERATE_EVERY == 0:
         model.eval()

         inp = random.choice(val_dataset)[:PRIME_LENGTH]
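
Putting the three changes together, generation in the train script now runs every GENERATE_EVERY steps and calls the model with both the flex masks and the triton kernel turned off. A minimal sketch of that decode step follows; the greedy argmax stands in for the script's top-k sampling and is an assumption, while the keyword names come from the diff above.

import torch

@torch.no_grad()
def next_token(net, out):
    # net is assumed to be the Transformer from this repo, out the running tensor of sampled ids
    logits = net(
        out,
        disable_flex = True,            # skip the flex attention mask path while decoding
        disable_triton_kernel = True    # fall back to the pure pytorch attention path
    )
    return logits[:, -1].argmax(dim = -1, keepdim = True)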
