@@ -22,6 +22,14 @@ def _fwd_kernel(
     q_stride_s,
     q_stride_h,
     q_stride_d,
+    k_stride_b,
+    k_stride_s,
+    k_stride_h,
+    k_stride_d,
+    v_stride_b,
+    v_stride_s,
+    v_stride_h,
+    v_stride_d,
     o_stride_b,
     o_stride_s,
     o_stride_h,
@@ -30,9 +38,9 @@ def _fwd_kernel(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-    cur_batch = tl.program_id(0)
+    cur_batch = tl.program_id(2)
     cur_head = tl.program_id(1)
-    start_m = tl.program_id(2)
+    start_m = tl.program_id(0)

     # initialize offsets
     offs_n = tl.arange(0, BLOCK_N)
@@ -49,9 +57,9 @@ def _fwd_kernel(
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         off_k = (
-            cur_batch * q_stride_b
-            + (start_n + offs_n[None, :]) * q_stride_s
-            + cur_head * q_stride_h
+            cur_batch * k_stride_b
+            + (start_n + offs_n[None, :]) * k_stride_s
+            + cur_head * k_stride_h
             + offs_d[:, None]
         )
         k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < seq_len, other=0.0)
@@ -71,9 +79,9 @@ def _fwd_kernel(

         # update acc
         off_v = (
-            cur_batch * q_stride_b
-            + (start_n + offs_n[:, None]) * q_stride_s
-            + cur_head * q_stride_h
+            cur_batch * v_stride_b
+            + (start_n + offs_n[:, None]) * v_stride_s
+            + cur_head * v_stride_h
             + offs_d[None, :]
         )
         v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < seq_len, other=0.0)
@@ -104,8 +112,8 @@ def flash_attention_fwd(
     batch_size, seq_len, head_num, head_dim = q.shape

     sm_scale = 1.0 / (head_dim ** 0.5)  # compute the softmax scale factor
-    grid = (batch_size, head_num, triton.cdiv(seq_len, BLOCK))  # batch, head,
-    # grid = (triton.cdiv(seq_len, BLOCK), batch_size, head_num)  # batch, head,
+    # grid = (batch_size, head_num, triton.cdiv(seq_len, BLOCK))  # batch, head, seq blocks
+    grid = (triton.cdiv(seq_len, BLOCK), head_num, batch_size)  # seq blocks, head, batch
     num_warps = 4
     _fwd_kernel[grid](
         q,
@@ -118,6 +126,14 @@ def flash_attention_fwd(
         q.stride(1),
         q.stride(2),
         q.stride(3),
+        k.stride(0),
+        k.stride(1),
+        k.stride(2),
+        k.stride(3),
+        v.stride(0),
+        v.stride(1),
+        v.stride(2),
+        v.stride(3),
         o.stride(0),
         o.stride(1),
         o.stride(2),
@@ -157,7 +173,6 @@ def test():
     k = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
     v = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
     o = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-
     torch_out = torch_att(q, k, v)
     import time

@@ -174,6 +189,3 @@ def test():
     print("max ", torch.max(torch.abs(torch_out - o)))
     print("mean ", torch.mean(torch.abs(torch_out - o)))
     assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
-
-
-# test()
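
A brief sketch (not part of the commit; the shapes and BLOCK value below are arbitrary assumptions) of what the two changes amount to: k and v now carry their own strides, so they no longer have to share q's memory layout, and the launch grid is reordered so the query-block index rides tl.program_id(0):

import torch
import triton

B, L, H, D, BLOCK = 2, 1024, 8, 64, 64

# separate k/v strides: a (B, H, L, D) allocation viewed as (B, L, H, D) has the
# same shape as q but different strides, so reusing q's strides for k/v would
# only be correct when all three tensors share the same layout
q = torch.empty(B, L, H, D)
k = torch.empty(B, H, L, D).permute(0, 2, 1, 3)
print(q.stride())  # (524288, 512, 64, 1)
print(k.stride())  # (524288, 64, 65536, 1)

# reordered grid: in the updated kernel, start_m = tl.program_id(0),
# cur_head = tl.program_id(1), cur_batch = tl.program_id(2)
grid = (triton.cdiv(L, BLOCK), H, B)
print(grid)  # (16, 8, 2)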