Skip to content

Commit 860517c

Browse files
committed
remove the get_at from inference
1 parent 56cce5d commit 860517c

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed
Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
11
from native_sparse_attention_pytorch.native_sparse_attention import (
22
SparseAttention
33
)
4-
5-
from native_sparse_attention_pytorch.triton_native_sparse_attention import (
6-
native_sparse_attend
7-
)

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,8 +418,14 @@ def forward_inference(
418418
sel_fk = rearrange(sel_fk, 'b h (w j) d -> b h w j d', j = self.selection_block_size)
419419
sel_fv = rearrange(sel_fv, 'b h (w j) d -> b h w j d', j = self.selection_block_size)
420420

421-
sel_fk = einx.get_at('b h [w] j d, b h 1 sel -> b h (sel j) d', sel_fk, sel_indices)
422-
sel_fv = einx.get_at('b h [w] j d, b h 1 sel -> b h (sel j) d', sel_fv, sel_indices)
421+
# get_at('b h [w] j d, b h 1 sel -> b h (sel j) d'
422+
423+
sel_indices = repeat(sel_indices, 'b h 1 sel -> b h sel j d', j = self.selection_block_size, d = sel_fk.shape[-1])
424+
425+
sel_fk = sel_fk.gather(2, sel_indices)
426+
sel_fv = sel_fv.gather(2, sel_indices)
427+
428+
sel_fk, sel_fv = tuple(rearrange(t, 'b h sel j d -> b h (sel j) d') for t in (sel_fk, sel_fv))
423429

424430
fmask = sel_scores > 1e-10
425431

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "native-sparse-attention-pytorch"
3-
version = "0.0.64"
3+
version = "0.0.65"
44
description = "Native Sparse Attention"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)