Commit 8ac5b7a

move head to first dimension
1 parent 2eb7cf6 commit 8ac5b7a

File tree

1 file changed (+30, −34 lines)


native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 30 additions & 34 deletions
@@ -16,6 +16,9 @@ def exists(v):
 def default(val, d):
     return val if exists(val) else d
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def round_up_multiple(n, mult):
     return ceil(n / mult) * mult
 
@@ -49,8 +52,8 @@ def is_contiguous(x: Tensor):
 
 @triton.heuristics(
     {
-        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK"] == 0,
-        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK"] == 0,
+        "EVEN_M": lambda args: divisible_by(args["seqlen_q"], args["BLOCK"]),
+        "EVEN_N": lambda args: divisible_by(args["seqlen_k"], args["BLOCK"]),
         "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
     }
 )
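
The hunk above only swaps the inline modulo checks for the new divisible_by helper; the heuristic flags compute the same thing as before. A minimal illustrative sketch of what they evaluate to, using a hypothetical args dict that is not part of the kernel code:

def divisible_by(num, den):
    return (num % den) == 0

# hypothetical kernel arguments, for illustration only
args = dict(seqlen_q = 1024, seqlen_k = 1000, BLOCK = 64, headdim = 64, BLOCK_HEADDIM = 64)

even_m = divisible_by(args["seqlen_q"], args["BLOCK"])     # True:  1024 is a multiple of 64
even_n = divisible_by(args["seqlen_k"], args["BLOCK"])     # False: 1000 is not, so boundary blocks typically need masked loads
even_headdim = args["headdim"] == args["BLOCK_HEADDIM"]    # True

print(even_m, even_n, even_headdim)    # True False True
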
@@ -335,14 +338,14 @@ def flash_attn_forward(
 ):
     q, k, v, kv_block_indices = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v, kv_block_indices)]
 
-    batch, seqlen_q, nheads, dim = q.shape
-    _, seqlen_k, _, _ = k.shape
+    batch, nheads, seqlen_q, dim, device = *q.shape, q.device
+    _, _, seqlen_k, _ = k.shape
 
     num_selected_fine_blocks = kv_block_indices.shape[-1]
     assert kv_block_indices.shape == kv_block_mask.shape
 
-    assert k.shape == (batch, seqlen_k, nheads, dim)
-    assert v.shape == (batch, seqlen_k, nheads, dim)
+    assert k.shape == (batch, nheads, seqlen_k, dim)
+    assert v.shape == (batch, nheads, seqlen_k, dim)
     assert dim <= 128, "only support head dimensions up to 128"
     assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
     assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
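
This unpack is the core of the layout change: q, k and v are now expected as (batch, heads, seq, dim) rather than (batch, seq, heads, dim). A small sketch (not from the commit, sizes arbitrary) of tensors that satisfy the new asserts:

import torch

batch, nheads, seqlen, dim = 2, 8, 1024, 64

# head-first layout expected after this commit
q = torch.randn(batch, nheads, seqlen, dim, dtype = torch.float16)
k = torch.randn(batch, nheads, seqlen, dim, dtype = torch.float16)
v = torch.randn(batch, nheads, seqlen, dim, dtype = torch.float16)

# mirrors the new unpack: batch, nheads, seqlen_q, dim = q.shape
b, h, n, d = q.shape
assert k.shape == (b, h, n, d) and v.shape == (b, h, n, d)
assert d <= 128
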
@@ -352,9 +355,9 @@ def flash_attn_forward(
 
     seqlen_q_rounded = round_up_multiple(seqlen_q, TRITON_BLOCK_SIZE)
 
-    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
+    lse = torch.empty((batch, nheads, seqlen_q_rounded), device = device, dtype = torch.float32)
 
-    m = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
+    m = torch.empty((batch, nheads, seqlen_q_rounded), device = device, dtype = torch.float32)
 
     o = torch.empty_like(q)
 
@@ -373,20 +376,20 @@ def flash_attn_forward(
         lse,
         softmax_scale,
         q.stride(0),
-        q.stride(2),
         q.stride(1),
+        q.stride(2),
         k.stride(0),
-        k.stride(2),
         k.stride(1),
+        k.stride(2),
         v.stride(0),
-        v.stride(2),
         v.stride(1),
+        v.stride(2),
         o.stride(0),
-        o.stride(2),
         o.stride(1),
+        o.stride(2),
         kv_block_indices.stride(0),
-        kv_block_indices.stride(2),
         kv_block_indices.stride(1),
+        kv_block_indices.stride(2),
         nheads,
         seqlen_q,
         seqlen_k,
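
The stride arguments are swapped to match the new memory layout: for a contiguous (batch, heads, seq, dim) tensor, the head stride is now stride(1) and the sequence stride is stride(2), so the kernel receives (batch stride, head stride, seq stride) in index order. A quick check with arbitrary sizes:

import torch

q = torch.empty(2, 8, 1024, 64)    # (batch, heads, seq, dim), contiguous

print(q.stride(0))    # 524288 = 8 * 1024 * 64, step to the next batch element
print(q.stride(1))    # 65536  = 1024 * 64,     step to the next head
print(q.stride(2))    # 64,                     step to the next query position
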
@@ -964,8 +967,8 @@ def flash_attn_backward(
     if not is_contiguous(do):
         do = do.contiguous()
 
-    batch, seqlen_q, nheads, dim = q.shape
-    _, seqlen_k, _, _ = k.shape
+    batch, nheads, seqlen_q, dim = q.shape
+    _, _, seqlen_k, _ = k.shape
 
     num_sel_fine_blocks = kv_block_indices.shape[-1]
     assert kv_block_indices.shape == kv_block_mask.shape
@@ -995,11 +998,11 @@ def flash_attn_backward(
         do,
         delta,
         o.stride(0),
-        o.stride(2),
         o.stride(1),
+        o.stride(2),
         do.stride(0),
-        do.stride(2),
         do.stride(1),
+        do.stride(2),
         nheads,
         seqlen_q,
         seqlen_q_rounded,
@@ -1027,29 +1030,29 @@ def flash_attn_backward(
         delta,
         softmax_scale,
         q.stride(0),
-        q.stride(2),
         q.stride(1),
+        q.stride(2),
         k.stride(0),
-        k.stride(2),
         k.stride(1),
+        k.stride(2),
         v.stride(0),
-        v.stride(2),
         v.stride(1),
+        v.stride(2),
         do.stride(0),
-        do.stride(2),
         do.stride(1),
+        do.stride(2),
         dq_accum.stride(0),
-        dq_accum.stride(2),
         dq_accum.stride(1),
+        dq_accum.stride(2),
         dk.stride(0),
-        dk.stride(2),
         dk.stride(1),
+        dk.stride(2),
         dv.stride(0),
-        dv.stride(2),
         dv.stride(1),
+        dv.stride(2),
         kv_block_indices.stride(0),
-        kv_block_indices.stride(2),
         kv_block_indices.stride(1),
+        kv_block_indices.stride(2),
         nheads,
         seqlen_q,
         seqlen_k,
@@ -1063,8 +1066,8 @@ def flash_attn_backward(
         BLOCK = block_size,
         NUM_SEL_KV_BLOCKS = num_sel_fine_blocks,
         SEQUENCE_PARALLEL = False,
-        EVEN_M = (seqlen_q % block_size) == 0,
-        EVEN_N = (seqlen_k % block_size) == 0,
+        EVEN_M = divisible_by(seqlen_q, block_size),
+        EVEN_N = divisible_by(seqlen_k, block_size),
         EVEN_HEADDIM = BLOCK_HEADDIM == dim
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
         # num_warps=num_warps,
@@ -1093,10 +1096,6 @@ def forward(
         fmask,
         num_grouped_queries
     ):
-        selected_block_indices, fmask = tuple(rearrange(t, 'b h i sel -> b i h sel') for t in (selected_block_indices, fmask))
-
-        fq, fk, fv = tuple(rearrange(t, 'b h n d -> b n h d') for t in (fq, fk, fv))
-
        dtype = fq.dtype
 
        fq, fk, fv = tuple(t.half() for t in (fq, fk, fv))
@@ -1111,12 +1110,10 @@ def forward(
         ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, lse)
         ctx._saved_variables = (block_size,)
 
-        out = rearrange(out, 'b n h d -> b h n d')
         return out.type(dtype)
 
     @classmethod
     def backward(self, ctx, do):
-        do = rearrange(do, 'b h n d -> b n h d')
 
         q, k, v, sel_block_indices, mask, out, lse = ctx.saved_tensors
 
@@ -1136,7 +1133,6 @@ def backward(self, ctx, do):
             block_size = block_size
         )
 
-        dq, dk, dv = tuple(rearrange(t, 'b n h d -> b h n d') for t in (dq, dk, dv))
         return dq, dk, dv, None, None, None, None
 
 native_sparse_attend = NSA.apply
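
With heads already in the first dimension, the autograd Function no longer needs the entry/exit rearranges deleted above. The sketch below is illustrative only (it uses einops, as the file already does) and just shows the round trip that was removed, i.e. the work callers of native_sparse_attend no longer pay for:

import torch
from einops import rearrange

b, h, n, d = 2, 8, 1024, 64
fq = torch.randn(b, h, n, d)

# previously applied on entry: head-first -> sequence-first for the kernel
fq_seq_first = rearrange(fq, 'b h n d -> b n h d')
assert fq_seq_first.shape == (b, n, h, d)

# and undone on the output; both hops are gone now that the kernel itself is head-first
out = rearrange(fq_seq_first, 'b n h d -> b h n d')
assert out.shape == (b, h, n, d)
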
