
Commit c391705

handle rotary embeddings for sliding windows explicitly
1 parent 232f4eb commit c391705

File tree: 2 files changed (+11 −8 lines changed)

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 10 additions & 7 deletions
@@ -115,7 +115,8 @@ def __init__(
             window_size = sliding_window_size,
             causal = True,
             exact_windowsize = True,
-            autopad = True
+            autopad = True,
+            use_rotary_pos_emb = False
         )

         self.sliding_window_size = sliding_window_size
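
Note on the config change above: the sliding window branch appears to be the LocalAttention module from lucidrains' local-attention package, which by default can apply its own rotary positional embedding. Setting use_rotary_pos_emb = False turns that internal rotation off, since forward() (see the later hunks) now rotates q and k once and hands the already-rotated tensors to this branch. A minimal sketch of the intent, with hypothetical sizes and assuming the local-attention and rotary-embedding-torch packages:

# minimal sketch, not the module's exact code; package names and sizes are assumptions
import torch
from local_attention import LocalAttention
from rotary_embedding_torch import RotaryEmbedding

dim_head = 64                          # hypothetical head dimension
rotary_emb = RotaryEmbedding(dim = 32) # rotate half the head dimension (a common choice)

sliding_window = LocalAttention(
    window_size = 32,                  # hypothetical sliding_window_size
    causal = True,
    exact_windowsize = True,
    autopad = True,
    use_rotary_pos_emb = False         # rotary is applied outside, so the module must not rotate q / k again
)

q = torch.randn(1, 8, 256, dim_head)   # (batch, heads, seq, dim_head)
k = torch.randn(1, 8, 256, dim_head)
v = torch.randn(1, 8, 256, dim_head)

# rotate once outside, then pass the rotated tensors to the sliding window branch
rotated_q, rotated_k = rotary_emb.rotate_queries_with_cached_keys(q, k)
out = sliding_window(rotated_q, rotated_k, v)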
@@ -234,6 +235,10 @@ def forward(

         compressed_attn_out = einsum(cattn, cv, 'b h i j, b h j d -> b h i d')

+        # for 2. and 3., will give them relative positions with rotary - compressed needs to be handled separately (even if they already have intra block absolute positions)
+
+        rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
         # 2. fine attention over selected based on compressed attention logits

         importance_scores = cattn[..., num_mem_compress_kv:]
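
For readers unfamiliar with the helper introduced above: in rotary-embedding-torch, rotate_queries_with_cached_keys applies rotary embeddings to queries and keys in one call, offsetting the query positions by the length difference so they line up with the tail of a possibly longer, cached key sequence; when q and k have equal length, as during training here, both are simply rotated at their absolute positions. A small sketch, assuming self.rotary_emb is a RotaryEmbedding from that package:

# minimal sketch, assuming rotary-embedding-torch; all shapes are hypothetical
import torch
from rotary_embedding_torch import RotaryEmbedding

rotary_emb = RotaryEmbedding(dim = 32)

# equal lengths: q and k are both rotated at positions 0 .. 127
q = torch.randn(2, 8, 128, 64)          # (batch, heads, seq, dim_head)
k = torch.randn(2, 8, 128, 64)
rotated_q, rotated_k = rotary_emb.rotate_queries_with_cached_keys(q, k)

# decoding-style call: one new query against a longer cached key sequence -
# the query is rotated with an offset so its position matches the last key position
q_step  = torch.randn(2, 8, 1, 64)
k_cache = torch.randn(2, 8, 129, 64)
rq, rk = rotary_emb.rotate_queries_with_cached_keys(q_step, k_cache)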
@@ -245,12 +250,10 @@ def forward(

         fmask = selected_importance_values > 1e-10

-        fq = q
-        fk = k
+        fq = rotated_q
+        fk = rotated_k
         fv = v

-        fq, fk = self.rotary_emb.rotate_queries_with_cached_keys(fq, fk)
-
         if seq_len < fine_divisible_seq_len:
             remainder = fine_divisible_seq_len - seq_len
             fk = pad_at_dim(fk, (0, remainder), value = 0., dim = -2)
@@ -319,9 +322,9 @@ def forward(
         # 3. overlapping sliding window, this is unsurprising and expected

         if exists(sliding_window_flex_mask):
-            sliding_window_attn_out = flex_attention(q, k, v, block_mask = sliding_window_flex_mask)
+            sliding_window_attn_out = flex_attention(rotated_q, rotated_k, v, block_mask = sliding_window_flex_mask)
         else:
-            sliding_window_attn_out = self.sliding_window(q, k, v)
+            sliding_window_attn_out = self.sliding_window(rotated_q, rotated_k, v)

         # combine strategies
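
The flex_attention branch above consumes a precomputed sliding_window_flex_mask. As a rough illustration of what such a call looks like with PyTorch's flex attention API (PyTorch 2.5+), here is a generic causal sliding-window block mask; the mask function and sizes below are illustrative assumptions, not necessarily the exact mask this repository builds:

# hedged sketch of a causal sliding-window block mask for flex_attention (requires PyTorch >= 2.5);
# the mask_mod and sizes are illustrative, not the repo's exact sliding_window_flex_mask
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

seq_len, window_size = 256, 32          # hypothetical values
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def sliding_window_causal(b, h, q_idx, kv_idx):
    # attend only to keys at or before the query, and within the window
    return (q_idx >= kv_idx) & ((q_idx - kv_idx) <= window_size)

block_mask = create_block_mask(
    sliding_window_causal,
    B = None, H = None,
    Q_LEN = seq_len, KV_LEN = seq_len,
    device = device
)

rotated_q = torch.randn(1, 8, seq_len, 64, device = device)
rotated_k = torch.randn(1, 8, seq_len, 64, device = device)
v         = torch.randn(1, 8, seq_len, 64, device = device)

sliding_window_attn_out = flex_attention(rotated_q, rotated_k, v, block_mask = block_mask)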

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.14"
+version = "0.0.15"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
