
Commit 0d96765

Merge pull request #4 from lancerts/patch-1
Update native_sparse_attention.py
2 parents 6ee9078 + 2808e7e

File tree

1 file changed: +2 -2 lines changed


native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 2 additions & 2 deletions
@@ -277,7 +277,7 @@ def forward(
         k_compress_input = self.split_compress_window(k[..., :compress_divisible_seq_len, :] + k_pos)
         v_compress_input = self.split_compress_window(v[..., :compress_divisible_seq_len, :] + v_pos)

-        ck = self.k_compress(k_compress_input)
+        ck = self.k_compress(k_compress_input) # Equation (7) of the Native Sparse Attention paper
         cv = self.v_compress(v_compress_input)

         # 1. coarse attention over compressed
@@ -320,7 +320,7 @@ def forward(

         importance_scores = reduce(importance_scores, 'b (grouped_queries h) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)

-        # handle if compress block size not equal to the fine block size
+        # handle if compress block size does not equal the fine block size
         # cannot parse their equation, so will just improvise
         # first we expand all the compressed scores to the full sequence length, then average within each fine / selection block size - pad on the right to 0s, which should be fine as sliding window covers the local anyways
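For context, the two ideas the diff touches can be sketched in isolation: Equation (7)-style key compression (each window of keys mapped to one compressed key) and the "improvised" expansion of compressed importance scores to fine selection blocks described in the comments. This is a minimal sketch under stated assumptions; the function names, the plain linear compressor, and the divisibility assumption are illustrative, not the repository's actual implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def compress_keys(k, k_compress, compress_block_size):
    # Loosely following Equation (7): map each window of
    # compress_block_size keys to a single compressed key.
    # k: (batch, heads, seq, dim); k_compress is any module taking
    # (compress_block_size * dim) -> dim (here a plain nn.Linear).
    b, h, n, d = k.shape
    divisible = (n // compress_block_size) * compress_block_size
    blocks = k[..., :divisible, :].reshape(
        b, h, divisible // compress_block_size, compress_block_size * d)
    return k_compress(blocks)  # (batch, heads, num_blocks, dim)

def expand_scores_to_fine_blocks(scores, compress_block_size, fine_block_size, seq_len):
    # The comment's strategy when compress block size != fine block size:
    # broadcast each compressed score over its window, right-pad with 0s
    # to the full sequence length, then average within each fine block.
    # scores: (batch, heads, queries, num_compressed_blocks)
    expanded = scores.repeat_interleave(compress_block_size, dim = -1)
    expanded = F.pad(expanded, (0, seq_len - expanded.shape[-1]), value = 0.)
    # assumes seq_len is divisible by fine_block_size for this sketch
    return expanded.reshape(
        *expanded.shape[:-1], seq_len // fine_block_size, fine_block_size
    ).mean(dim = -1)
```

Zero-padding the tail is benign for the reason the comment gives: the sliding-window branch already attends to the local positions that the truncated compressed blocks would miss.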

0 commit comments