Commit 9db5fec

Merge pull request #225 from flash-algo/optime-triton-kernels
Add utility functions for device management and input validation
2 parents: a786fa8 + 9bdb282

File tree: 2 files changed, +373 / -69 lines (only flash_sparse_attn/ops/triton/flash_fwd.py is expanded below)

flash_sparse_attn/ops/triton/flash_fwd.py

Lines changed: 148 additions & 69 deletions
@@ -3,19 +3,16 @@
 import triton
 import triton.language as tl
 
-from flash_sparse_attn.ops.triton import seqlen_info, block_info, mask, softmax
+from flash_sparse_attn.ops.triton import utils, seqlen_info, block_info, mask, softmax
+
+
+fwd_base_autotune_configs = utils.get_fwd_base_autotune_configs(True)
 
 
 @triton.autotune(
-    configs=[
-        triton.Config({"TILE_M": 128, "TILE_N": 128}, num_warps=4, num_stages=1),
-        triton.Config({"TILE_M": 128, "TILE_N": 64}, num_warps=4, num_stages=1),
-        triton.Config({"TILE_M": 64, "TILE_N": 64}, num_warps=4, num_stages=1),
-        triton.Config({"TILE_M": 128, "TILE_N": 128}, num_warps=4, num_stages=2),
-        triton.Config({"TILE_M": 128, "TILE_N": 64}, num_warps=4, num_stages=2),
-        triton.Config({"TILE_M": 64, "TILE_N": 64}, num_warps=4, num_stages=2),
-    ],
-    key=["IS_CAUSAL", "IS_LOCAL", "TILE_K"],
+    configs=fwd_base_autotune_configs,
+    key=utils.FWD_BASE_AUTOTUNE_KEYS,
+    use_cuda_graph=True,
 )
 @triton.jit
 def _fwd_base_kernel(
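
The helpers referenced in this hunk live in the commit's second file, which is not expanded on this page. As a rough sketch, reconstructed from the config and key lists this hunk removes, the new module might look like the following; the function and constant names are taken from the call sites above, while the meaning of the boolean parameter is a guess:

# Hypothetical reconstruction: tile shapes, num_warps, and num_stages are
# taken verbatim from the removed configs; the flag's meaning is assumed.
import triton

FWD_BASE_AUTOTUNE_KEYS = ["IS_CAUSAL", "IS_LOCAL", "TILE_K"]


def get_fwd_base_autotune_configs(include_two_stage_configs: bool = True):
    # Tile shapes previously hard-coded at the @triton.autotune call site.
    tile_shapes = [(128, 128), (128, 64), (64, 64)]
    stage_options = [1, 2] if include_two_stage_configs else [1]
    return [
        triton.Config({"TILE_M": tile_m, "TILE_N": tile_n}, num_warps=4, num_stages=s)
        for s in stage_options
        for tile_m, tile_n in tile_shapes
    ]

Centralizing the search space this way lets the base and varlen forward entry points (and any future kernels) share one autotune configuration.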
@@ -216,6 +213,18 @@ def _fwd_base_kernel(
             )
         else:
             tl.store(lse_ptrs, lse_tile, boundary_check=(0,))
+
+        # We can't get dtype of query for output here, so we initialize output to zero
+        # # Write output as zero for proper handling
+        # if PACK_GQA:
+        #     tl.store(
+        #         out_ptrs,
+        #         o_tile,
+        #         mask=((offs_m // QHEADS_PER_KVHEAD_PACKGQA) < actual_seqlen_q)[:, None]
+        #         & (offs_kb < head_dim)[None, :],
+        #     )
+        # else:
+        #     tl.store(out_ptrs, o_tile, boundary_check=(0, 1))
         return
 
     # Create query pointers
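
The commented-out store above records why the host wrappers below allocate the output with torch.zeros_like rather than torch.empty_like: a program that exits early at this point has not loaded a query tile, hence has no dtype to cast a zero tile to, so rows the kernel never visits must already be zero in the output buffer.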
@@ -454,82 +463,151 @@ def _fwd_base_kernel(
         tl.store(out_ptrs, acc_o.to(q_tile.dtype), boundary_check=(0, 1))
 
 
-def _flash_attn_forward(
+def _flash_attn_base_forward(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     softmax_scale: float,
     is_causal: bool = False,
     window_size: Optional[Tuple[int, int]] = None,
-    cu_seqlens_q: Optional[torch.Tensor] = None,
-    cu_seqlens_k: Optional[torch.Tensor] = None,
-    max_seqlen_q: Optional[int] = None,
-    max_seqlen_k: Optional[int] = None,
     pack_gqa: bool = False,
 ):
-    is_varlen = cu_seqlens_q is not None and cu_seqlens_k is not None
-    if not is_varlen:
-        batch_size, seqlen_q, num_heads_q, head_dim = query.shape
-        _, seqlen_k, num_heads_kv, _ = key.shape
-    else:
-        total_seqlen_q, num_heads_q, head_dim = query.shape
-        _, num_heads_kv, _ = key.shape
-        batch_size = cu_seqlens_q.shape[0] - 1
-        seqlen_q = max_seqlen_q
-        seqlen_k = max_seqlen_k
+    batch_size, seqlen_q, num_heads_q, head_dim = query.shape
+    _, seqlen_k, num_heads_kv, _ = key.shape
 
     is_local = window_size[0] is not None or window_size[1] is not None
     if is_local:
         window_size_left, window_size_right = window_size
     else:
         window_size_left, window_size_right = None, None
 
-    assert query.is_cuda and key.is_cuda and value.is_cuda, (
-        "All inputs must be on CUDA device"
+    utils.assert_fwd_base_inputs(
+        query,
+        key,
+        value,
+        cu_seqlens_q=None,
+        cu_seqlens_k=None,
+        num_heads_q=num_heads_q,
+        num_heads_kv=num_heads_kv,
+        head_dim=head_dim,
     )
-    assert query.dtype in [torch.float16, torch.bfloat16], (
-        "Input dtype must be float16 or bfloat16"
+
+    softmax_scale = softmax_scale or 1.0 / (head_dim**0.5)
+
+    out = torch.zeros_like(query)
+    lse = torch.empty(
+        (batch_size, num_heads_q, seqlen_q),
+        device=query.device,
+        dtype=torch.float32,
     )
-    assert query.dtype == key.dtype == value.dtype, (
-        "All inputs must have the same dtype"
+
+    TILE_K = max(triton.next_power_of_2(head_dim), 16)
+
+    grid = utils.get_fwd_base_grid(
+        batch_size=batch_size,
+        seqlen_q=seqlen_q,
+        num_heads_q=num_heads_q,
+        num_heads_kv=num_heads_kv,
+        pack_gqa=pack_gqa,
     )
-    assert num_heads_q % num_heads_kv == 0, (
-        "num_heads_q must be divisible by num_heads_kv"
+
+    _fwd_base_kernel[grid](
+        query,
+        key,
+        value,
+        out,
+        lse,
+        softmax_scale,
+        query.stride(0),
+        query.stride(-2),
+        query.stride(-3),
+        key.stride(0),
+        key.stride(-2),
+        key.stride(-3),
+        value.stride(0),
+        value.stride(-2),
+        value.stride(-3),
+        out.stride(0),
+        out.stride(-2),
+        out.stride(-3),
+        lse.stride(0),
+        lse.stride(1),
+        None,
+        None,
+        None,
+        None,
+        num_heads_q // num_heads_kv,
+        seqlen_q,
+        seqlen_k,
+        head_dim,
+        QHEADS_PER_KVHEAD_PACKGQA=(num_heads_q // num_heads_kv) if pack_gqa else 1,
+        TILE_K=TILE_K,
+        IS_CAUSAL=is_causal,
+        IS_LOCAL=is_local,
+        WINDOW_SIZE_LEFT=window_size_left,
+        WINDOW_SIZE_RIGHT=window_size_right,
+        HAS_CU_SEQLENS_Q=False,
+        HAS_CU_SEQLENS_K=False,
+        HAS_SEQUSED_Q=False,
+        HAS_SEQUSED_K=False,
+        PACK_GQA=pack_gqa,
     )
-    assert head_dim % 16 == 0, (
-        "head_dim must be a multiple of 16 for efficient memory access"
+
+    return out, lse, softmax_scale
+
+
+def _flash_attn_varlen_base_forward(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    softmax_scale: float,
+    is_causal: bool = False,
+    window_size: Optional[Tuple[int, int]] = None,
+    pack_gqa: bool = False,
+):
+    total_seqlen_q, num_heads_q, head_dim = query.shape
+    _, num_heads_kv, _ = key.shape
+    batch_size = cu_seqlens_q.shape[0] - 1
+    seqlen_q = max_seqlen_q
+    seqlen_k = max_seqlen_k
+
+    is_local = window_size[0] is not None or window_size[1] is not None
+    if is_local:
+        window_size_left, window_size_right = window_size
+    else:
+        window_size_left, window_size_right = None, None
+
+    utils.assert_fwd_base_inputs(
+        query,
+        key,
+        value,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        num_heads_q=num_heads_q,
+        num_heads_kv=num_heads_kv,
+        head_dim=head_dim,
     )
-    assert head_dim <= 256, "head_dim must be less than or equal to 256"
-    if is_varlen:
-        assert (
-            cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
-        ), "cu_seqlens_q and cu_seqlens_k must be of int32"
 
     softmax_scale = softmax_scale or 1.0 / (head_dim**0.5)
 
     out = torch.zeros_like(query)
-    if not is_varlen:
-        lse = torch.empty(
-            (batch_size, num_heads_q, seqlen_q),
-            device=query.device,
-            dtype=torch.float32,
-        )
-    else:
-        lse = torch.empty(
-            (total_seqlen_q, num_heads_q), device=query.device, dtype=torch.float32
-        )
+    lse = torch.empty(
+        (total_seqlen_q, num_heads_q), device=query.device, dtype=torch.float32
+    )
 
     TILE_K = max(triton.next_power_of_2(head_dim), 16)
 
-    def grid(META):
-        return (
-            triton.cdiv(
-                seqlen_q * (num_heads_q // num_heads_kv) if pack_gqa else seqlen_q,
-                META["TILE_M"],
-            ),
-            num_heads_kv if pack_gqa else num_heads_q,
-            batch_size,
-        )
+    grid = utils.get_fwd_base_grid(
+        batch_size=batch_size,
+        seqlen_q=seqlen_q,
+        num_heads_q=num_heads_q,
+        num_heads_kv=num_heads_kv,
+        pack_gqa=pack_gqa,
+    )
 
     _fwd_base_kernel[grid](
         query,
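
Based on the deleted assert statements, utils.assert_fwd_base_inputs (also in the unexpanded second file) plausibly consolidates the validation both wrappers share. A minimal sketch under that assumption; the signature matches the call sites above, and gating the int32 check on the cu_seqlens arguments mirrors the old is_varlen branch:

from typing import Optional

import torch


def assert_fwd_base_inputs(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor],
    cu_seqlens_k: Optional[torch.Tensor],
    num_heads_q: int,
    num_heads_kv: int,
    head_dim: int,
):
    # Checks lifted verbatim from the pre-refactor _flash_attn_forward.
    assert query.is_cuda and key.is_cuda and value.is_cuda, (
        "All inputs must be on CUDA device"
    )
    assert query.dtype in [torch.float16, torch.bfloat16], (
        "Input dtype must be float16 or bfloat16"
    )
    assert query.dtype == key.dtype == value.dtype, (
        "All inputs must have the same dtype"
    )
    assert num_heads_q % num_heads_kv == 0, (
        "num_heads_q must be divisible by num_heads_kv"
    )
    assert head_dim % 16 == 0, (
        "head_dim must be a multiple of 16 for efficient memory access"
    )
    assert head_dim <= 256, "head_dim must be less than or equal to 256"
    if cu_seqlens_q is not None and cu_seqlens_k is not None:
        # Varlen path only: offsets tensors must be int32.
        assert (
            cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
        ), "cu_seqlens_q and cu_seqlens_k must be of int32"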
@@ -538,18 +616,18 @@ def grid(META):
         out,
         lse,
         softmax_scale,
-        query.stride(0) if not is_varlen else 0,
+        0,
         query.stride(-2),
-        query.stride(-3) if not is_varlen else query.stride(0),
-        key.stride(0) if not is_varlen else 0,
+        query.stride(0),
+        0,
         key.stride(-2),
-        key.stride(-3) if not is_varlen else key.stride(0),
-        value.stride(0) if not is_varlen else 0,
+        key.stride(0),
+        0,
         value.stride(-2),
-        value.stride(-3) if not is_varlen else value.stride(0),
-        out.stride(0) if not is_varlen else 0,
+        value.stride(0),
+        0,
         out.stride(-2),
-        out.stride(-3) if not is_varlen else out.stride(0),
+        out.stride(0),
         lse.stride(0),
         lse.stride(1),
         cu_seqlens_q,
@@ -566,10 +644,11 @@ def grid(META):
         IS_LOCAL=is_local,
         WINDOW_SIZE_LEFT=window_size_left,
         WINDOW_SIZE_RIGHT=window_size_right,
-        HAS_CU_SEQLENS_Q=is_varlen,
-        HAS_CU_SEQLENS_K=is_varlen,
+        HAS_CU_SEQLENS_Q=True,
+        HAS_CU_SEQLENS_K=True,
         HAS_SEQUSED_Q=False,
         HAS_SEQUSED_K=False,
         PACK_GQA=pack_gqa,
     )
+
     return out, lse, softmax_scale
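
utils.get_fwd_base_grid replaces the per-wrapper grid closure that this commit deletes. A sketch mirroring the removed def grid(META) body: when pack_gqa is set, each program covers the rows of every query head that shares one KV head, so the M axis grows by num_heads_q // num_heads_kv while the head axis of the launch grid shrinks from num_heads_q to num_heads_kv.

import triton


def get_fwd_base_grid(batch_size, seqlen_q, num_heads_q, num_heads_kv, pack_gqa):
    def grid(META):
        return (
            # Programs along M; TILE_M is chosen per config by the autotuner.
            triton.cdiv(
                seqlen_q * (num_heads_q // num_heads_kv) if pack_gqa else seqlen_q,
                META["TILE_M"],
            ),
            # One program per KV head when packing GQA, else per query head.
            num_heads_kv if pack_gqa else num_heads_q,
            batch_size,
        )

    return grid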
