
Commit 2808e7e

Update native_sparse_attention.py
1 parent 115279f commit 2808e7e

File tree

1 file changed (+2, -2 lines)


native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 2 additions & 2 deletions
@@ -278,7 +278,7 @@ def forward(
         k_compress_input = self.split_compress_window(k[..., :compress_divisible_seq_len, :] + k_pos)
         v_compress_input = self.split_compress_window(v[..., :compress_divisible_seq_len, :] + v_pos)

-        ck = self.k_compress(k_compress_input)
+        ck = self.k_compress(k_compress_input) # Equation (7) of the Native Sparse Attention paper
         cv = self.v_compress(v_compress_input)

         # 1. coarse attention over compressed
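For context on the new comment (not part of the commit): the `k_compress` / `v_compress` calls belong to the compressed-attention branch, where each window of keys or values produced by `split_compress_window` is pooled by a learned module into a single compressed token. Below is a minimal, self-contained PyTorch sketch of one way such block-wise compression can be written; the class name `BlockCompress`, the MLP shape, and the block size are illustrative assumptions, not the repository's actual implementation.

# Hedged sketch of block-wise key/value compression in the spirit of the
# compressed branch (cf. Equation (7) of the Native Sparse Attention paper).
# All names, shapes, and the MLP architecture are illustrative assumptions.

import torch
from torch import nn
from einops import rearrange

class BlockCompress(nn.Module):
    def __init__(self, dim_head, block_size):
        super().__init__()
        self.block_size = block_size
        # maps one whole block of keys (or values) to a single compressed token
        self.net = nn.Sequential(
            nn.Linear(dim_head * block_size, dim_head),
            nn.SiLU(),
            nn.Linear(dim_head, dim_head),
        )

    def forward(self, k):
        # k: (batch, heads, seq, dim_head), seq assumed divisible by block_size
        blocks = rearrange(k, 'b h (w n) d -> b h w (n d)', n = self.block_size)
        return self.net(blocks)  # (batch, heads, num_blocks, dim_head)

# hypothetical usage: ck = BlockCompress(64, 16)(k[..., :compress_divisible_seq_len, :])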
@@ -321,7 +321,7 @@ def forward(

         importance_scores = reduce(importance_scores, 'b (grouped_queries h) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)

-        # handle if compress block size not equal to the fine block size
+        # handle if compress block size does not equal the fine block size
         # cannot parse their equation, so will just improvise
         # first we expand all the compressed scores to the full sequence length, then average within each fine / selection block size - pad on the right with 0s, which should be fine as the sliding window covers the local context anyway

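For context on the comments above: the improvised handling broadcasts each compressed block's importance score over the positions that block covers, then mean-pools those per-position scores into fine / selection blocks, zero-padding the tail. A minimal Python sketch of that idea follows; the function name, argument names, and the ceiling-division padding are assumptions for illustration, not the repository's actual code.

# Hedged sketch of the "expand then average" handling described in the
# comments above. Function and argument names are illustrative assumptions.

import torch
import torch.nn.functional as F
from einops import repeat, reduce

def expand_then_pool(importance_scores, compress_block_size, fine_block_size, seq_len):
    # importance_scores: (batch, heads, num_compressed_blocks)

    # 1. broadcast each compressed block's score over every position it covers
    per_position = repeat(importance_scores, 'b h w -> b h (w n)', n = compress_block_size)

    # 2. right-pad with zeros out to a multiple of the fine / selection block
    #    size (this also covers positions past the compress-divisible prefix)
    padded_len = -(-seq_len // fine_block_size) * fine_block_size
    per_position = F.pad(per_position, (0, padded_len - per_position.shape[-1]), value = 0.)

    # 3. average within each fine / selection block
    return reduce(per_position, 'b h (w n) -> b h w', 'mean', n = fine_block_size)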
