+import os
 import torch
 import triton
 import triton.language as tl
@@ -151,6 +152,18 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
     tl.store(O_block_ptr, acc.to(Out.type.element_ty))


+configs = [
+    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
+    for BM in [256] \
+    for BN in [32, 64] \
+    for s in [3] \
+    for w in [32] \
+]
+
+tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'])
+tune_attn_fwd = tuner(_attn_fwd)
+
+
 def forward(q, k, v, causal, sm_scale):
     # shape constraints
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
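For context, `triton.autotune` can equivalently be stacked as a decorator; the sketch below mirrors the config sweep above on a toy kernel (`_toy_fwd` is a placeholder, not the tutorial's `_attn_fwd`). Applying it functionally, as this patch does, leaves the undecorated `_attn_fwd` callable for the advanced path further down.

```python
# Minimal sketch of the decorator form of triton.autotune; the kernel
# body is a placeholder, not the real attention forward pass.
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': bn}, num_stages=3, num_warps=32)
        for bn in (32, 64)
    ],
    key=['N_CTX', 'BLOCK_DMODEL'],  # cache one winning config per key combination
)
@triton.jit
def _toy_fwd(X, N_CTX, BLOCK_DMODEL: tl.constexpr,
             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)  # a real kernel would process a BLOCK_M x BLOCK_N tile here
```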
@@ -162,23 +175,38 @@ def forward(q, k, v, causal, sm_scale):
     num_stages = 3
     num_warps = 8 if Lq == 64 else 16
     stage = 3 if causal else 1
-    grid = (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], BLOCK_M))
+    grid = lambda args: (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], args['BLOCK_M']))
     M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
-    _attn_fwd[grid](
-        q, k, v, sm_scale, M, o,  #
-        q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
-        k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
-        v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
-        o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
-        q.shape[0], q.shape[1],  #
-        N_CTX=q.shape[2],  #
-        BLOCK_M=BLOCK_M,  #
-        BLOCK_N=BLOCK_N,  #
-        BLOCK_DMODEL=Lk,  #
-        STAGE=stage,  #
-        num_warps=num_warps,  #
-        num_stages=num_stages  #
-    )
+
+    if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0':
+        # default pipeline
+        tune_attn_fwd[grid](
+            q, k, v, sm_scale, M, o,  #
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
+            q.shape[0], q.shape[1],  #
+            N_CTX=q.shape[2],  #
+            BLOCK_DMODEL=Lk,  #
+            STAGE=stage,  #
+        )
+    else:
+        _attn_fwd[grid](
+            q, k, v, sm_scale, M, o,  #
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
+            q.shape[0], q.shape[1],  #
+            N_CTX=q.shape[2],  #
+            BLOCK_M=BLOCK_M,  #
+            BLOCK_N=BLOCK_N,  #
+            BLOCK_DMODEL=Lk,  #
+            STAGE=stage,  #
+            num_warps=num_warps,  #
+            num_stages=num_stages  #
+        )
     return o


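The grid change above is needed because autotuned meta-parameters such as `BLOCK_M` are only known after a config is selected, so Triton accepts a callable grid and passes it the chosen meta-parameters at launch. A self-contained sketch of the pattern (the `_copy` kernel and the `n` key are illustrative, not from the tutorial):

```python
# Hypothetical stand-alone example of a callable grid with autotuning.
import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[triton.Config({'BLOCK': b}, num_warps=4) for b in (128, 256)],
    key=['n'],
)
@triton.jit
def _copy(X, Y, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(Y + offs, tl.load(X + offs, mask=mask), mask=mask)

x = torch.arange(1000, device='xpu', dtype=torch.float32)  # 'cuda' on NVIDIA GPUs
y = torch.empty_like(x)
# BLOCK is unknown until autotuning picks a config, so the grid is a
# callable that receives the selected meta-parameters.
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK']),)
_copy[grid](x, y, x.numel())
```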
@@ -243,7 +271,6 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
     elif provider == 'triton':
         # FIXME: remove the condition below once extend attention supports Causal = True
         # https://github.com/intel/intel-xpu-backend-for-triton/issues/1102
-        import os
         if os.environ.get('TRITON_INTEL_ADVANCED_PATH', '0') == '1' and CAUSAL:
             min_ms, max_ms, mean, cv = (float('inf'), ) * 4
         else:
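The local `import os` in the benchmark is dropped because the first hunk now imports `os` at module level. The guard itself is a skip-sentinel pattern: configurations the backend cannot run yet report infinities so downstream tables and plots keep their shape. A hedged sketch of the idea (the helper name is hypothetical):

```python
import os

def advanced_path_skips(causal: bool) -> bool:
    # The advanced path cannot run causal attention yet (issue #1102), so
    # the benchmark reports inf for min_ms/max_ms/mean/cv instead of timing it.
    return os.environ.get('TRITON_INTEL_ADVANCED_PATH', '0') == '1' and causal
```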