@@ -156,7 +156,7 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
156156 triton .Config ({'BLOCK_M' : BM , 'BLOCK_N' : BN , 'grf_mode' : 'large' }, num_stages = s , num_warps = w ) \
157157 for BM in [128 , 256 ] \
158158 for BN in [32 , 64 ] \
159- for s in [3 , 4 ] \
159+ for s in [2 , 3 , 4 ] \
160160 for w in [8 , 16 , 32 ] \
161161 ]
162162
@@ -170,43 +170,25 @@ def forward(q, k, v, causal, sm_scale):
170170 assert Lq == Lk and Lk == Lv
171171 assert Lk in {16 , 32 , 64 , 128 }
172172 o = torch .empty_like (q , dtype = torch .float32 )
173- BLOCK_M = 128
174- BLOCK_N = 64 if Lk <= 64 else 32
175- num_stages = 3
176- num_warps = 8 if Lq == 64 else 16
173+ # BLOCK_M = 128
174+ # BLOCK_N = 64 if Lk <= 64 else 32
175+ # num_stages = 3
176+ # num_warps = 8 if Lq == 64 else 16
177177 stage = 3 if causal else 1
178178 grid = lambda args : (q .shape [0 ], q .shape [1 ], triton .cdiv (q .shape [2 ], args ['BLOCK_M' ]))
179179 M = torch .empty ((q .shape [0 ], q .shape [1 ], q .shape [2 ]), device = q .device , dtype = torch .float32 )
180180
181- if os .getenv ('TRITON_INTEL_ADVANCED_PATH' , '0' ) == '0' :
182- # default pipeline
183- tune_attn_fwd [grid ](
184- q , k , v , sm_scale , M , o , #
185- q .stride (0 ), q .stride (1 ), q .stride (2 ), q .stride (3 ), #
186- k .stride (0 ), k .stride (1 ), k .stride (2 ), k .stride (3 ), #
187- v .stride (0 ), v .stride (1 ), v .stride (2 ), v .stride (3 ), #
188- o .stride (0 ), o .stride (1 ), o .stride (2 ), o .stride (3 ), #
189- q .shape [0 ], q .shape [1 ], #
190- N_CTX = q .shape [2 ], #
191- BLOCK_DMODEL = Lk , #
192- STAGE = stage , #
193- )
194- else :
195- _attn_fwd [grid ](
196- q , k , v , sm_scale , M , o , #
197- q .stride (0 ), q .stride (1 ), q .stride (2 ), q .stride (3 ), #
198- k .stride (0 ), k .stride (1 ), k .stride (2 ), k .stride (3 ), #
199- v .stride (0 ), v .stride (1 ), v .stride (2 ), v .stride (3 ), #
200- o .stride (0 ), o .stride (1 ), o .stride (2 ), o .stride (3 ), #
201- q .shape [0 ], q .shape [1 ], #
202- N_CTX = q .shape [2 ], #
203- BLOCK_M = BLOCK_M , #
204- BLOCK_N = BLOCK_N , #
205- BLOCK_DMODEL = Lk , #
206- STAGE = stage , #
207- num_warps = num_warps , #
208- num_stages = num_stages #
209- )
181+ tune_attn_fwd [grid ](
182+ q , k , v , sm_scale , M , o , #
183+ q .stride (0 ), q .stride (1 ), q .stride (2 ), q .stride (3 ), #
184+ k .stride (0 ), k .stride (1 ), k .stride (2 ), k .stride (3 ), #
185+ v .stride (0 ), v .stride (1 ), v .stride (2 ), v .stride (3 ), #
186+ o .stride (0 ), o .stride (1 ), o .stride (2 ), o .stride (3 ), #
187+ q .shape [0 ], q .shape [1 ], #
188+ N_CTX = q .shape [2 ], #
189+ BLOCK_DMODEL = Lk , #
190+ STAGE = stage , #
191+ )
210192 return o
211193
212194
0 commit comments