Commit f32ed38

oops

1 parent cfa6e08 commit f32ed38

3 files changed: +3 -3 lines changed

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 1 addition & 1 deletion
@@ -399,7 +399,7 @@ def forward(
 
         if self.use_diff_topk:
             gates = straight_through(selected_importance_values, 1.)
-            gates = gates.cumsum(dim = -1)[..., -1]
+            gates = gates.cumprod(dim = -1)[..., -1]
             gates = repeat(gates, 'b h ... -> b (h qh) ...', qh = self.num_grouped_queries)
 
         if exists(fine_selection_flex_mask):
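
For context on the one-line fix above: with use_diff_topk enabled, straight_through forces each selected-block gate to 1. in the forward pass while letting gradients reach the underlying importance values. Taking the last element of a cumsum over those gates yields k (the number of selected blocks), silently scaling the fine branch by k; the last element of a cumprod stays at 1. A minimal sketch, assuming straight_through is the usual straight-through-estimator helper (the name comes from the diff; the body here is an assumption):

import torch

def straight_through(t, target):
    # assumed helper: forward pass takes the value `target`,
    # backward pass routes gradients through `t`
    return t + (target - t).detach()

# toy importance values for k = 3 selected blocks
vals = torch.tensor([0.9, 0.7, 0.4], requires_grad = True)

gates = straight_through(vals, 1.)               # forward value: [1., 1., 1.]

gate_sum  = gates.cumsum(dim = -1)[..., -1]      # forward value: 3. -> fine branch scaled by k
gate_prod = gates.cumprod(dim = -1)[..., -1]     # forward value: 1. -> fine branch unscaled

gate_prod.backward()
print(vals.grad)                                 # tensor([1., 1., 1.]) -- gradient still flows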

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.39"
+version = "0.0.40"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

train.py

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ def base_decoding(
         ),
         selection_block_size = 32,
         num_selected_blocks = 2,
-        use_diff_topk = False,
+        use_diff_topk = True,
         interpolated_importance_score = True
     )
 ).cuda()
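
Flipping use_diff_topk to True in the training script exercises the differentiable top-k path patched above (the `if self.use_diff_topk:` branch in native_sparse_attention.py), so the selected-block gates receive gradients during training rather than block selection being a hard, non-differentiable choice.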
