@@ -1,11 +1,13 @@
 # mypy: allow-untyped-defs
 """Triton Implementation of the flex_attention Kernel"""

+from __future__ import annotations
+
 import logging
 import math
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Optional, TYPE_CHECKING, Union


 import sympy
@@ -35,6 +37,10 @@
 from .flex_decoding import _use_flex_decoding, create_flex_decoding_kernel


+if TYPE_CHECKING:
+    from ...template_heuristics.triton import FlexBwDConfig, FlexConfig
+
+
 log = logging.getLogger(__name__)
 aten = torch.ops.aten
 Expr = sympy.Expr
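
For context on the import changes above: with `from __future__ import annotations`, every annotation is stored as a string and evaluated lazily, so names imported only under `if TYPE_CHECKING:` are visible to type checkers but never imported at runtime. A minimal, self-contained sketch of that pattern, with a hypothetical `pick_configs` helper added purely for illustration:

```python
from __future__ import annotations  # PEP 563: annotations become lazily evaluated strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright); never executed at
    # runtime, so it adds no import-time cost and cannot create import cycles.
    from ...template_heuristics.triton import FlexConfig


def pick_configs(head_dim: int) -> list[FlexConfig]:  # hypothetical helper
    # "list[FlexConfig]" is just a string here, so FlexConfig does not need
    # to exist at runtime for this definition to work.
    return []
```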
@@ -279,7 +285,7 @@ def flex_attention(

     dtype = query.get_dtype()
     head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
-    configs = V.choices.get_flex_attention_fwd_configs(
+    configs: list[FlexConfig] = V.choices.get_flex_attention_fwd_configs(
         head_dim, dtype, query.get_device().type
     )

@@ -719,20 +725,21 @@ def flex_attention_backward(*args, **kwargs):

     dtype = query.get_dtype()
     head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
-    configs = V.choices.get_flex_attention_bwd_configs(
+    configs: list[FlexBwDConfig] = V.choices.get_flex_attention_bwd_configs(
         head_dim, dtype, query.get_device().type
     )

     # Default config for warp specialization
     num_consumer_groups, num_buffers_warp_spec = 0, 0

     original_kernel_options = kernel_options.copy()
+
     for conf in configs:
         if (
-            SPARSE_KV_BLOCK_SIZE % conf.block_m != 0
-            or SPARSE_Q_BLOCK_SIZE % conf.block_m != 0
-            or SPARSE_KV_BLOCK_SIZE % conf.block_n != 0
-            or SPARSE_Q_BLOCK_SIZE % conf.block_n != 0
+            SPARSE_KV_BLOCK_SIZE % conf.block_n1 != 0
+            or SPARSE_Q_BLOCK_SIZE % conf.block_m1 != 0
+            or SPARSE_KV_BLOCK_SIZE % conf.block_n2 != 0
+            or SPARSE_Q_BLOCK_SIZE % conf.block_m2 != 0
         ):
             continue

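
The backward configs consumed above carry four independent tile sizes, and the new divisibility check pairs them by axis: the BLOCK_M* values tile the query axis (checked against SPARSE_Q_BLOCK_SIZE) while the BLOCK_N* values tile the key/value axis (checked against SPARSE_KV_BLOCK_SIZE). A rough, illustrative stand-in for such a config and the filter, assuming the real `FlexBwDConfig` in `template_heuristics/triton.py` exposes at least these fields:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FlexBwDConfigSketch:
    """Illustrative stand-in; only the attribute names used in the diff are assumed."""

    block_m1: int  # query-axis tile size for one backward pass
    block_n1: int  # key/value-axis tile size for that pass
    block_m2: int  # query-axis tile size for the other backward pass
    block_n2: int  # key/value-axis tile size for that pass


def divides_sparse_blocks(
    conf: FlexBwDConfigSketch, sparse_q_block: int, sparse_kv_block: int
) -> bool:
    # Mirrors the filter above: every tile size must evenly divide the
    # block-sparse mask granularity along its own axis, otherwise a tile
    # would straddle a mask-block boundary.
    return (
        sparse_kv_block % conf.block_n1 == 0
        and sparse_q_block % conf.block_m1 == 0
        and sparse_kv_block % conf.block_n2 == 0
        and sparse_q_block % conf.block_m2 == 0
    )
```

A config whose tiles do not line up, e.g. `block_n1=96` against `SPARSE_KV_BLOCK_SIZE=128`, is simply skipped rather than adjusted.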
@@ -755,10 +762,10 @@ def flex_attention_backward(*args, **kwargs):
                "num_buffers_warp_spec", num_buffers_warp_spec
            )

-        cur_kernel_options.setdefault("BLOCK_M1", conf.block_m)
-        cur_kernel_options.setdefault("BLOCK_N1", conf.block_n)
-        cur_kernel_options.setdefault("BLOCK_M2", conf.block_n)
-        cur_kernel_options.setdefault("BLOCK_N2", conf.block_m)
+        cur_kernel_options.setdefault("BLOCK_M1", conf.block_m1)
+        cur_kernel_options.setdefault("BLOCK_N1", conf.block_n1)
+        cur_kernel_options.setdefault("BLOCK_M2", conf.block_m2)
+        cur_kernel_options.setdefault("BLOCK_N2", conf.block_n2)

         # Blocksparse options
         cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)