Commit 1f11855

fix some padding issues for gating with importance score
1 parent 77e55aa commit 1f11855

File tree

1 file changed: +4 −0 lines


native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 4 additions & 0 deletions
@@ -657,6 +657,9 @@ def forward(
 
         selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)
 
+        if exists(gates):
+            gates = pad_at_dim(gates, (0, remainder), value = 0, dim = -2)
+
         # handle block causal diagonal in the diagram, but run experiments without to see
 
         fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
@@ -693,6 +696,7 @@ def forward(
         # differential topk gating
 
         if self.use_diff_topk:
+            gates = F.pad(gates, (0, 1), value = 1.)
             fk = einx.multiply('b h i sel, b h i sel j d -> b h i sel j d', gates, fk)
 
         # merge selected key values
