@@ -13,26 +13,20 @@ def _fwd_kernel(
     K,
     V,
     sm_scale,
-    seq_len,
     Out,
-    q_stride_b,
     q_stride_s,
     q_stride_h,
     q_stride_d,
-    k_stride_b,
     k_stride_s,
     k_stride_h,
     k_stride_d,
-    v_stride_b,
     v_stride_s,
     v_stride_h,
     v_stride_d,
-    o_stride_b,
     o_stride_s,
     o_stride_h,
     o_stride_d,
     head_dim_act,
-    is_varlen: tl.constexpr,
     cu_seqlens,
     BLOCK_M: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
@@ -42,29 +36,17 @@ def _fwd_kernel(
     cur_head = tl.program_id(1)
     start_m = tl.program_id(0)

-    if is_varlen == 1:
-        seq_start = tl.load(cu_seqlens + cur_batch).to(tl.int32)
-        seq_end = tl.load(cu_seqlens + cur_batch + 1).to(tl.int32)
-        seq_len = seq_end - seq_start
-        q_stride_b = 0
-        k_stride_b = 0
-        v_stride_b = 0
-        o_stride_b = 0
-    else:
-        seq_start = 0
+    seq_start = tl.load(cu_seqlens + cur_batch).to(tl.int32)
+    seq_end = tl.load(cu_seqlens + cur_batch + 1).to(tl.int32)
+    seq_len = seq_end - seq_start

     # initialize offsets
     offs_n = tl.arange(0, BLOCK_N)
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)

     mask_d = offs_d < head_dim_act
-    off_q = (
-        cur_batch * q_stride_b
-        + cur_head * q_stride_h
-        + (seq_start + offs_m[:, None]) * q_stride_s
-        + offs_d[None, :] * q_stride_d
-    )
+    off_q = cur_head * q_stride_h + (seq_start + offs_m[:, None]) * q_stride_s + offs_d[None, :] * q_stride_d
     q = tl.load(Q + off_q, mask=(offs_m[:, None] < seq_len) & mask_d[None, :], other=0.0)
     # initialize pointer to m and l
     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
@@ -75,15 +57,14 @@ def _fwd_kernel(
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         off_k = (
-            cur_batch * k_stride_b
-            + (seq_start + start_n + offs_n[None, :]) * k_stride_s
+            (seq_start + start_n + offs_n[None, :]) * k_stride_s
             + cur_head * k_stride_h
             + offs_d[:, None] * k_stride_d
         )
         k = tl.load(K + off_k, mask=((start_n + offs_n[None, :]) < seq_len) & mask_d[:, None], other=0.0)

         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, k, out_dtype=tl.float32, allow_tf32=False)
+        qk += tl.dot(q, k)
         qk *= sm_scale
         qk += tl.where((start_n + offs_n[None, :]) < seq_len, 0, float("-inf"))

@@ -97,8 +78,7 @@ def _fwd_kernel(

         # update acc
         off_v = (
-            cur_batch * v_stride_b
-            + (seq_start + start_n + offs_n[:, None]) * v_stride_s
+            (seq_start + start_n + offs_n[:, None]) * v_stride_s
             + cur_head * v_stride_h
             + offs_d[None, :] * v_stride_d
         )
@@ -115,12 +95,7 @@ def _fwd_kernel(
     o_scale = tl.exp(m_i - l_i)
     acc = acc * o_scale[:, None]
     # initialize pointers to output
-    off_o = (
-        cur_batch * o_stride_b
-        + (seq_start + offs_m[:, None]) * o_stride_s
-        + cur_head * o_stride_h
-        + offs_d[None, :] * o_stride_d
-    )
+    off_o = (seq_start + offs_m[:, None]) * o_stride_s + cur_head * o_stride_h + offs_d[None, :] * o_stride_d
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc, mask=(offs_m[:, None] < seq_len) & mask_d[None, :])
     return
@@ -132,49 +107,57 @@ def _flash_attention_triton_fwd(
     v,
     o,
     cu_seqlens=None,  # q k v cu_seqlens,
-    max_seqlens=None,
+    max_seqlen=None,
 ):
     BLOCK = 64
     # shape constraints
+    assert q.shape == k.shape == v.shape == o.shape, "q, k, v, o must have the same shape"

-    batch_size, seq_len, head_num, head_dim = q.shape
-    if cu_seqlens is not None and max_seqlens is not None:
-        assert q.shape[0] == 1
+    if q.ndim == 4:
+        bs, seq_len, head_num, head_dim = q.shape
+        total_len = bs * seq_len
+        reshape_fn = lambda t: t.view(total_len, head_num, head_dim)
+        q, k, v, o = [reshape_fn(x) for x in (q, k, v, o)]
+    elif q.ndim == 3:
+        total_len, head_num, head_dim = q.shape
+    else:
+        raise ValueError("q,k,v,o must be 3d or 4d")
+
+    if cu_seqlens is None:  # fixed-length input: build cumulative sequence lengths ourselves
+        cu_seqlens = torch.arange(bs + 1, dtype=torch.int32, device=q.device) * seq_len
+    else:
         cu_seqlens = cu_seqlens.to(q.device, torch.int32)
-        seq_len = max_seqlens
-        batch_size = cu_seqlens.numel() - 1
+
+    if max_seqlen is None:
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+
+    batch_size = cu_seqlens.numel() - 1

     d_pad = triton.next_power_of_2(head_dim)
     sm_scale = 1.0 / (head_dim**0.5)  # softmax scale factor
-    # grid = (batch_size, head_num, triton.cdiv(seq_len, BLOCK))  # batch, head,
-    grid = (triton.cdiv(seq_len, BLOCK), head_num, batch_size)  # batch, head,
+
+    grid = (triton.cdiv(max_seqlen, BLOCK), head_num, batch_size)  # (seq blocks, heads, batch)
     num_warps = 4
     _fwd_kernel[grid](
         q,
         k,
         v,
         sm_scale,
-        seq_len,
         o,
         q.stride(0),
         q.stride(1),
         q.stride(2),
-        q.stride(3),
         k.stride(0),
         k.stride(1),
         k.stride(2),
-        k.stride(3),
         v.stride(0),
         v.stride(1),
         v.stride(2),
-        v.stride(3),
         o.stride(0),
         o.stride(1),
         o.stride(2),
-        o.stride(3),
         head_dim,
-        is_varlen=1 if cu_seqlens is not None else 0,
-        cu_seqlens=0 if cu_seqlens is None else cu_seqlens,
+        cu_seqlens,
         BLOCK_M=BLOCK,
         BLOCK_DMODEL=d_pad,
         BLOCK_N=BLOCK,
@@ -198,10 +181,17 @@ def flash_attention_v3_fwd(
     v,
     o,
     cu_seqlens=None,
-    max_seqlens=None,
+    max_seqlen=None,
 ):
     head_dim = q.shape[-1]
     softmax_scale = head_dim**-0.5
+    if cu_seqlens is not None:
+        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
+    if q.ndim == 4:
+        bs, seq_len, head_num, head_dim = q.shape
+        total_len = bs * seq_len
+        reshape_fn = lambda t: t.view(total_len, head_num, head_dim)
+        q, k, v, o = [reshape_fn(x) for x in (q, k, v, o)]
     _flash_attn_forward(
         q,
         k,
@@ -214,8 +204,8 @@ def flash_attention_v3_fwd(
         None,  # cu_seqlens_q/k/k_new
         None,
         None,  # seqused_q/k
-        max_seqlens,
-        max_seqlens,  # max_seqlen_q/k
+        max_seqlen,
+        max_seqlen,  # max_seqlen_q/k
         None,
         None,
         None,  # page_table, kv_batch_idx, leftpad_k,
@@ -239,15 +229,15 @@ def flash_attention_v3_fwd(
     _flash_attn_v3_available = False


-def flash_attention_fwd(q, k, v, o, cu_seqlens=None, max_seqlens=None):
+def flash_attention_fwd(q, k, v, o, cu_seqlens=None, max_seqlen=None):
     """
     Unified Flash Attention interface. If _flash_attn_forward is available,
     use flash_attention_v3_fwd; otherwise fall back to the Triton version.
     """
     if _flash_attn_v3_available and is_hopper():
-        flash_attention_v3_fwd(q, k, v, o, cu_seqlens, max_seqlens)
+        flash_attention_v3_fwd(q, k, v, o, cu_seqlens, max_seqlen)
     else:
-        _flash_attention_triton_fwd(q, k, v, o, cu_seqlens, max_seqlens)
+        _flash_attention_triton_fwd(q, k, v, o, cu_seqlens, max_seqlen)


 def torch_att(q, k, v):
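Usage sketch (not part of the patch): how the unified entry point can be driven after this change, either with a regular padded batch or with a packed varlen batch plus cu_seqlens. The shapes, dtypes, and device below are illustrative assumptions; it assumes flash_attention_fwd from the module above is in scope.

import torch

# Fixed-length path: 4D input, the Triton wrapper builds cu_seqlens internally.
bs, seq_len, head_num, head_dim = 2, 128, 8, 64
q = torch.randn(bs, seq_len, head_num, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
o = torch.empty_like(q)
flash_attention_fwd(q, k, v, o)

# Varlen path: two sequences of lengths 100 and 50 packed into one (total_len, H, D) tensor.
cu_seqlens = torch.tensor([0, 100, 150], dtype=torch.int32, device="cuda")
total_len = int(cu_seqlens[-1])
q_p = torch.randn(total_len, head_num, head_dim, device="cuda", dtype=torch.float16)
k_p, v_p = torch.randn_like(q_p), torch.randn_like(q_p)
o_p = torch.empty_like(q_p)
flash_attention_fwd(q_p, k_p, v_p, o_p, cu_seqlens=cu_seqlens, max_seqlen=100)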