Skip to content

Commit 4ae5c11

Browse files
committed
feedback from researchers has been positive about tensor typing, so will use it here
1 parent db38d04 commit 4ae5c11

File tree

3 files changed

+43
-4
lines changed

3 files changed

+43
-4
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from torch import Tensor
2+
3+
from jaxtyping import (
4+
Float,
5+
Int,
6+
Bool
7+
)
8+
9+
# jaxtyping is a misnomer, works for pytorch
10+
11+
class TorchTyping:
    """Adapter giving jaxtyping annotations a torch-first subscript syntax.

    Wraps one of the jaxtyping abstract dtypes (``Float`` / ``Int`` / ``Bool``)
    so that ``Float['b n d']`` expands to ``Float[Tensor, 'b n d']`` — the
    array type is always ``torch.Tensor`` and only the shape string varies.
    """

    def __init__(self, abstract_dtype):
        # the jaxtyping annotation class being wrapped
        self.abstract_dtype = abstract_dtype

    def __getitem__(self, shapes: str):
        # delegate to the wrapped dtype, pinning the array type to torch.Tensor
        dtype = self.abstract_dtype
        return dtype[Tensor, shapes]
17+
18+
# shadow the jaxtyping dtypes with torch-bound wrappers:
# after this, Float['b n d'] means Float[Tensor, 'b n d']
Float = TorchTyping(Float)
Int = TorchTyping(Int)
Bool = TorchTyping(Bool)

# fix: __all__ must contain the *names* as strings, not the objects —
# listing the objects makes `from ... import *` raise
# "TypeError: Item in __all__ must be str"
__all__ = [
    'Float',
    'Int',
    'Bool'
]

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
from native_sparse_attention_pytorch.tensor_typing import Float, Int, Bool
23

34
# taken from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py
45
# with fixes for triton 2.3
@@ -1462,11 +1463,22 @@ def backward(self, ctx, do, _):
14621463

14631464
_native_sparse_attend = NSA.apply
14641465

1466+
# ein notation
1467+
1468+
# b - batch
1469+
# qh - query heads
1470+
# kh - key / value heads
1471+
# n - token sequence
1472+
# d - attention head dimension
1473+
# sel - selected indices
1474+
14651475
def native_sparse_attend(
1466-
fq, fk, fv,
1467-
block_size,
1468-
selected_block_indices,
1469-
fmask,
1476+
fq: Float['b qh n d'],
1477+
fk: Float['b kh n d'],
1478+
fv: Float['b kh n d'],
1479+
block_size: int,
1480+
selected_block_indices: Int['b qh sel'] | Int['b kh sel'],
1481+
fmask: Bool['b qh sel'] | Bool['b kh sel'],
14701482
return_lse = False
14711483
):
14721484
seq_len = fq.shape[-2]

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ classifiers=[
2525
dependencies = [
2626
"einx>=0.3.0",
2727
"einops>=0.8.1",
28+
"jaxtyping",
2829
"local-attention>=1.11.1",
2930
"rotary-embedding-torch",
3031
"torch>=2.5",

0 commit comments

Comments
 (0)