Skip to content

Commit 2b9c54d

Browse files
committed
small cleanup
1 parent e1d9419 commit 2b9c54d

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,16 @@ def round_up_mult(n, mult):
8282
def divisible_by(num, den):
    """Return whether `num` is an exact multiple of `den`."""
    remainder = num % den
    return remainder == 0
8484

85+
# tensor helpers
86+
8587
def pad_at_dim(t, pad, dim = -1, value = 0.):
    """Pad tensor `t` along a single dimension.

    Args:
        t: tensor to pad.
        pad: pad amounts for dimension `dim`, unpacked after the zero pairs
            and handed to `F.pad` (normally a `(left, right)` pair).
        dim: dimension to pad; negative values count from the last dimension.
        value: fill value passed through to `F.pad`.

    Returns:
        The result of `F.pad` with only dimension `dim` padded.

    Raises:
        IndexError: if `dim` is outside `[-t.ndim, t.ndim - 1]`.
    """
    ndim = t.ndim

    # a positive out-of-range `dim` would make the count below negative, which
    # empties the zero prefix and silently pads the last dimension instead
    if not -ndim <= dim < ndim:
        raise IndexError(
            f'dim {dim} is out of range for a tensor with {ndim} dimension(s)'
        )

    dims_from_right = (-dim - 1) if dim < 0 else (ndim - dim - 1)

    # F.pad consumes pad pairs starting from the last dimension, so every
    # dimension to the right of `dim` gets an explicit (0, 0) pair first
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)
8991

92+
def straight_through(t, target):
    """Return `t` shifted to `target` by a detached correction term.

    The correction `(target - t)` is detached, so the returned tensor
    evaluates to `target` while gradients flow back through `t` only.
    """
    correction = (target - t).detach()
    return t + correction
94+
9095
# classes
9196

9297
class SparseAttention(Module):
@@ -306,7 +311,7 @@ def forward(
306311
selected_importance_values, selected_block_indices = importance_scores.topk(num_selected, dim = -1)
307312

308313
if self.use_diff_topk:
309-
gates = selected_importance_values + (1. - selected_importance_values).detach()
314+
gates = straight_through(selected_importance_values, 1.)
310315

311316
fmask = selected_importance_values > 1e-10
312317

0 commit comments

Comments
 (0)