 import triton
 import triton.language as tl
 import math
+import time
 import torch.nn.functional as F
+from typing import Optional, Tuple
 from lightllm.utils.device_utils import is_hopper

 if triton.__version__ >= "2.1.0":
@@ -82,9 +84,7 @@ def _fwd_kernel(
             + cur_head * v_stride_h
             + offs_d[None, :] * v_stride_d
         )
-        v = tl.load(V + off_v, mask=((start_n + offs_n[:, None]) < seq_len) & mask_d[None, :], other=0.0).to(
-            tl.float32
-        )
+        v = tl.load(V + off_v, mask=((start_n + offs_n[:, None]) < seq_len) & mask_d[None, :], other=0.0)
         p = p.to(v.dtype)
         acc += tl.dot(p, v)
         # update m_i and l_i
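
Not part of the patch: a minimal standalone sketch (all names below are illustrative) of why the explicit `tl.float32` cast on `v` can be dropped. `tl.dot` on fp16 tiles still accumulates into an fp32 result by default, so only `p` needs to be cast to `v.dtype`, and the accumulator `acc` stays fp32.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _dot_acc_demo(A, B, Out, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    ptrs = offs[:, None] * BLOCK + offs[None, :]
    a = tl.load(A + ptrs)  # fp16 tile
    b = tl.load(B + ptrs)  # fp16 tile
    acc = tl.zeros((BLOCK, BLOCK), dtype=tl.float32)
    acc += tl.dot(a, b)  # fp16 x fp16 dot, fp32 accumulation by default
    tl.store(Out + ptrs, acc)


def demo():
    BLOCK = 16
    a = torch.randn((BLOCK, BLOCK), device="cuda", dtype=torch.float16)
    b = torch.randn((BLOCK, BLOCK), device="cuda", dtype=torch.float16)
    out = torch.empty((BLOCK, BLOCK), device="cuda", dtype=torch.float32)
    _dot_acc_demo[(1,)](a, b, out, BLOCK=BLOCK)
    print(torch.max(torch.abs(out - a.float() @ b.float())).item())
```

Keeping `v` in its native fp16 also keeps the `tl.dot` on the tensor-core path while precision is preserved by the fp32 accumulator.
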
@@ -106,36 +106,17 @@ def _flash_attention_triton_fwd(
     k,
     v,
     o,
-    cu_seqlens=None,  # q k v cu_seqlens,
-    max_seqlen=None,
+    cu_seqlens,  # cumulative sequence offsets shared by q, k, v
+    max_seqlen,
 ):
     BLOCK = 64
     # shape constraints
-    assert q.shape == k.shape == v.shape == o.shape, "q, k, v, o must have the same shape"
-
-    if q.ndim == 4:
-        bs, seq_len, head_num, head_dim = q.shape
-        total_len = bs * seq_len
-        reshape_fn = lambda t: t.view(total_len, head_num, head_dim)
-        q, k, v, o = [reshape_fn(x) for x in (q, k, v, o)]
-    elif q.ndim == 3:
-        total_len, head_num, head_dim = q.shape
-    else:
-        raise ValueError("q,k,v,o must be 3d or 4d")
-
-    if cu_seqlens is None:  # fixed-length batch
-        cu_seqlens = torch.arange(bs + 1, dtype=torch.int32, device=q.device) * seq_len
-    else:
-        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
-
-    if max_seqlen is None:
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-
+    assert q.ndim == k.ndim == v.ndim == o.ndim == 3, "q, k, v, o must be 3D tensors"
+    _, head_num, head_dim = q.shape
     batch_size = cu_seqlens.numel() - 1

-    d_pad = triton.next_power_of_2(head_dim)
     sm_scale = 1.0 / (head_dim ** 0.5)  # softmax scale factor
-
+    d_pad = triton.next_power_of_2(head_dim)
     grid = (triton.cdiv(max_seqlen, BLOCK), head_num, batch_size)  # (seq blocks, heads, batch)
     num_warps = 4
     _fwd_kernel[grid](
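
Because the fixed-length fallback above was removed, `cu_seqlens` and `max_seqlen` are now the caller's responsibility. A minimal sketch for a batch of equal-length sequences, mirroring the deleted fallback logic (the helper name is illustrative, not part of this module):

```python
import torch


def fixed_length_cu_seqlens(batch_size: int, seq_len: int, device: str = "cuda"):
    # Same offsets the removed branch produced: [0, seq_len, 2 * seq_len, ...], int32.
    cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * seq_len
    return cu_seqlens, seq_len  # max_seqlen == seq_len when every sequence has the same length
```
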
@@ -180,18 +161,11 @@ def flash_attention_v3_fwd(
     k,
     v,
     o,
-    cu_seqlens=None,
-    max_seqlen=None,
+    cu_seqlens,
+    max_seqlen,
 ):
     head_dim = q.shape[-1]
     softmax_scale = head_dim ** -0.5
-    if cu_seqlens is not None:
-        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
-    if q.ndim == 4:
-        bs, seq_len, head_num, head_dim = q.shape
-        total_len = bs * seq_len
-        reshape_fn = lambda t: t.view(total_len, head_num, head_dim)
-        q, k, v, o = [reshape_fn(x) for x in (q, k, v, o)]
     _flash_attn_forward(
         q,
         k,
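
The 4-D `(batch, seq_len, head_num, head_dim)` path was dropped from this wrapper as well, so padded inputs have to be flattened to the packed layout before the call. A sketch mirroring the removed `reshape_fn` (the helper name is an assumption):

```python
import torch


def flatten_padded(t: torch.Tensor) -> torch.Tensor:
    # (batch, seq_len, head_num, head_dim) -> (batch * seq_len, head_num, head_dim),
    # the 3-D packed layout both wrappers now require.
    bs, seq_len, head_num, head_dim = t.shape
    return t.view(bs * seq_len, head_num, head_dim)
```
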
@@ -229,7 +203,7 @@ def flash_attention_v3_fwd(
     _flash_attn_v3_available = False


-def flash_attention_fwd(q, k, v, o, cu_seqlens=None, max_seqlen=None):
+def flash_attention_fwd(q, k, v, o, cu_seqlens, max_seqlen):
     """
     Unified Flash Attention interface: use flash_attention_v3_fwd when
     _flash_attn_forward is available, otherwise fall back to the Triton version.
@@ -238,44 +212,3 @@ def flash_attention_fwd(q, k, v, o, cu_seqlens=None, max_seqlen=None):
         flash_attention_v3_fwd(q, k, v, o, cu_seqlens, max_seqlen)
     else:
         _flash_attention_triton_fwd(q, k, v, o, cu_seqlens, max_seqlen)
-
-
-def torch_att(q, k, v):
-    head_dim = q.shape[-1]
-    q = q.transpose(1, 2)
-    k = k.transpose(1, 2)
-    v = v.transpose(1, 2)
-    scale = head_dim ** -0.5
-    attn = (q * scale) @ k.transpose(-2, -1)
-    attn = attn.softmax(dim=-1)
-    out = attn @ v
-    out = out.transpose(1, 2).contiguous()
-    return out
-
-
-def test():
-    import torch
-    import numpy as np
-
-    B, L, H, D = 4, 1025, 7, 128
-    dtype = torch.float16
-    q = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-    k = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-    v = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-    o = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-    torch_out = torch_att(q, k, v)
-    import time
-
-    torch.cuda.synchronize()
-    a = time.time()
-    for i in range(100):
-        flash_attention_fwd(q, k, v, o)
-        # o = torch_att(q, k, v)
-    torch.cuda.synchronize()
-    b = time.time()
-    # print(o.shape, torch_out.shape)
-    print((b - a) / 100 * 1000)
-
-    print("max ", torch.max(torch.abs(torch_out - o)))
-    print("mean ", torch.mean(torch.abs(torch_out - o)))
-    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
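
The removed `test()` above still called `flash_attention_fwd(q, k, v, o)` with the old optional arguments. A hedged sketch of how such a check could look against the new required-argument signature; the import path and helper names are assumptions, not part of the patch:

```python
import torch

from lightllm_flash_attention import flash_attention_fwd  # assumed module path; adjust to where this file lives


def reference_attention(q, k, v):
    # Plain PyTorch attention over one (seq_len, head_num, head_dim) sequence.
    scale = q.shape[-1] ** -0.5
    q, k, v = (t.transpose(0, 1) for t in (q, k, v))  # -> (head_num, seq_len, head_dim)
    attn = ((q * scale) @ k.transpose(-2, -1)).softmax(dim=-1)
    return (attn @ v).transpose(0, 1).contiguous()


def check():
    torch.manual_seed(0)
    head_num, head_dim = 7, 128
    lens = [5, 128, 1025]  # a variable-length batch
    cu_seqlens = torch.zeros(len(lens) + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = torch.cumsum(torch.tensor(lens, device="cuda"), dim=0).to(torch.int32)
    max_seqlen = max(lens)
    total = sum(lens)
    q, k, v = (torch.randn(total, head_num, head_dim, device="cuda", dtype=torch.float16) for _ in range(3))
    o = torch.empty_like(q)
    flash_attention_fwd(q, k, v, o, cu_seqlens, max_seqlen)
    for i in range(len(lens)):
        s, e = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        ref = reference_attention(q[s:e].float(), k[s:e].float(), v[s:e].float()).half()
        print(i, torch.max(torch.abs(ref - o[s:e])).item())
```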