@@ -56,6 +56,7 @@ def _attn_fwd_subtile(
     v,
     dtype: tl.constexpr,
     STAGE: tl.constexpr,
+    SUBTILING: tl.constexpr,
 ):
     qk = tl.dot(q, k)
     if STAGE == 2:
@@ -75,10 +76,13 @@ def _attn_fwd_subtile(
     BM: tl.constexpr = acc.shape[0]
     BN: tl.constexpr = acc.shape[1]

-    acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
-    acc0 = acc0 * alpha[:, None]
-    acc1 = acc1 * alpha[:, None]
-    acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
+    if SUBTILING:
+        acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
+        acc0 = acc0 * alpha[:, None]
+        acc1 = acc1 * alpha[:, None]
+        acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
+    else:
+        acc = acc * alpha[:, None]

     # prepare p and v for the dot
     p = p.to(dtype)
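
Note on what the `SUBTILING` branch changes: both paths produce the same rescaled accumulator; the subtiled path splits the `[BM, BN]` tile into two `[BM, BN // 2]` halves so the rescale runs on narrower register tiles. A minimal NumPy sketch of the layout equivalence (NumPy stands in for the Triton ops; shapes are illustrative):

```python
import numpy as np

# Stand-in shapes; BM x BN is the accumulator tile.
BM, BN = 4, 8
acc = np.arange(BM * BN, dtype=np.float32).reshape(BM, BN)
alpha = np.linspace(0.5, 2.0, BM).astype(np.float32)

# reshape([BM, 2, BN // 2]) + permute(0, 2, 1) + split() picks out the two
# contiguous column halves of the tile.
halves = acc.reshape(BM, 2, BN // 2).transpose(0, 2, 1)
acc0, acc1 = halves[..., 0], halves[..., 1]
assert np.array_equal(acc0, acc[:, : BN // 2])
assert np.array_equal(acc1, acc[:, BN // 2 :])

# Rescaling each half and re-joining matches the full-width multiply of the
# else branch exactly; SUBTILING changes the schedule, not the result.
rejoined = (
    np.stack([acc0 * alpha[:, None], acc1 * alpha[:, None]], axis=-1)
    .transpose(0, 2, 1)
    .reshape(BM, BN)
)
assert np.allclose(rejoined, acc * alpha[:, None])
```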
@@ -117,6 +121,7 @@ def _attn_fwd_inner_oss_dp(
     offs_n: tl.constexpr,  #
     N_CTX: tl.constexpr,
     warp_specialize: tl.constexpr,
+    SUBTILING: tl.constexpr,
 ):
     # range of values handled by this stage
     if STAGE == 1:
@@ -139,10 +144,34 @@ def _attn_fwd_inner_oss_dp(
         v = desc_v.load([offsetkv_y, 0])

         l_i0, m_i0, acc0 = _attn_fwd_subtile(
-            q0, k, offs_m0, start_n, offs_n, qk_scale, l_i0, m_i0, acc0, v, dtype, STAGE
+            q0,
+            k,
+            offs_m0,
+            start_n,
+            offs_n,
+            qk_scale,
+            l_i0,
+            m_i0,
+            acc0,
+            v,
+            dtype,
+            STAGE,
+            SUBTILING,
         )
         l_i1, m_i1, acc1 = _attn_fwd_subtile(
-            q1, k, offs_m1, start_n, offs_n, qk_scale, l_i1, m_i1, acc1, v, dtype, STAGE
+            q1,
+            k,
+            offs_m1,
+            start_n,
+            offs_n,
+            qk_scale,
+            l_i1,
+            m_i1,
+            acc1,
+            v,
+            dtype,
+            STAGE,
+            SUBTILING,
         )

         offsetkv_y += BLOCK_N
@@ -174,15 +203,17 @@ def _host_descriptor_pre_hook(nargs):

 configs = [
     triton.Config(
-        {"BLOCK_M": BM, "BLOCK_N": BN},
+        {"BLOCK_M": BM, "BLOCK_N": BN, "SUBTILING": subtile},
         num_stages=s,
         num_warps=w,
         pre_hook=_host_descriptor_pre_hook,
+        # ir_override=f"/home/mren/OpenSource/tritonbench/override/_attn_fwd_persist.ttgir"
     )
     for BM in [256]
     for BN in [128]
     for s in NUM_STAGES_OPTIONS
     for w in [4]
+    for subtile in [True]
 ]

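Each key in the `triton.Config` dict is forwarded to the kernel as a `constexpr`, which is how `"SUBTILING": subtile` reaches the `SUBTILING: tl.constexpr` parameter added above. As written, `for subtile in [True]` pins the subtiled path rather than letting the autotuner compare it against the plain rescale; a hypothetical widening of the sweep (not part of this change) would time both:

```python
# Hypothetical sweep over both rescale paths; everything else as above.
configs = [
    triton.Config(
        {"BLOCK_M": BM, "BLOCK_N": BN, "SUBTILING": subtile},
        num_stages=s,
        num_warps=w,
        pre_hook=_host_descriptor_pre_hook,
    )
    for BM in [256]
    for BN in [128]
    for s in NUM_STAGES_OPTIONS
    for w in [4]
    for subtile in [True, False]  # assumption: both variants are worth timing
]
```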
@@ -222,6 +253,8 @@ def _attn_fwd_tma_dp(
     desc_k,
     desc_v,
     desc_o,
+    pid,
+    off_hz,
     N_CTX,  #
     HEAD_DIM: tl.constexpr,  #
     BLOCK_M: tl.constexpr,  #
@@ -230,10 +263,11 @@ def _attn_fwd_tma_dp(
     STAGE: tl.constexpr,  #
     warp_specialize: tl.constexpr,  #
     dtype: tl.constexpr,
+    SUBTILING: tl.constexpr,
 ):
     tl.static_assert(BLOCK_N <= HEAD_DIM)
-    start_m = tl.program_id(0)
-    off_hz = tl.program_id(1)
+    start_m = pid  # tl.program_id(0)
+    # off_hz = tl.program_id(1)
     off_z = off_hz // H
     off_h = off_hz % H

@@ -283,6 +317,7 @@ def _attn_fwd_tma_dp(
             offs_n,
             N_CTX,  #
             warp_specialize,
+            SUBTILING,
         )
     if STAGE & 2:
         acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(
@@ -309,6 +344,7 @@ def _attn_fwd_tma_dp(
             offs_n,
             N_CTX,  #
             warp_specialize,
+            SUBTILING,
         )

     m_i0 += tl.math.log2(l_i0)
@@ -324,6 +360,56 @@ def _attn_fwd_tma_dp(
     desc_o.store([qo_offset_y + BLOCK_M // 2, 0], acc1.to(dtype))


+@triton.autotune(
+    configs=list(filter(keep, configs)),
+    key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"],
+    prune_configs_by={"early_config_prune": prune_invalid_configs},
+)
+@triton.jit
+def _attn_fwd(
+    sm_scale,
+    M,  #
+    Z,
+    H,
+    desc_q,
+    desc_k,
+    desc_v,
+    desc_o,
+    N_CTX,  #
+    HEAD_DIM: tl.constexpr,  #
+    BLOCK_M: tl.constexpr,  #
+    BLOCK_N: tl.constexpr,  #
+    FP8_OUTPUT: tl.constexpr,  #
+    STAGE: tl.constexpr,  #
+    warp_specialize: tl.constexpr,  #
+    dtype: tl.constexpr,
+    SUBTILING: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    off_hz = tl.program_id(1)
+    _attn_fwd_tma_dp(
+        sm_scale,
+        M,
+        Z,
+        H,
+        desc_q,
+        desc_k,
+        desc_v,
+        desc_o,
+        pid,
+        off_hz,
+        N_CTX,
+        HEAD_DIM,
+        BLOCK_M,
+        BLOCK_N,
+        FP8_OUTPUT,
+        STAGE,
+        warp_specialize,
+        dtype,
+        SUBTILING,
+    )
+
+
 @triton.autotune(
     configs=list(filter(keep, configs)),
     key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"],
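
This new `_attn_fwd` entry point explains the `pid`/`off_hz` plumbing in the earlier hunks: `_attn_fwd_tma_dp` no longer reads `tl.program_id` itself, so the same body can be driven either by this flat one-tile-per-program launch or by the persistent kernel's tile loop. A plain-Python stand-in for the pattern (hypothetical names, not the kernel itself):

```python
# Shared body parameterized on its "program id", plus two drivers: a flat
# launch where the grid supplies the ids, and a persistent launch where a
# fixed number of programs loop over tiles and derive the ids themselves.
def body(pid, off_hz):
    return (pid, off_hz)

def launch_flat(num_tiles, num_heads):
    return [body(pid, hz) for hz in range(num_heads) for pid in range(num_tiles)]

def launch_persistent(num_tiles, num_heads, num_progs):
    out = []
    for prog in range(num_progs):
        # strided walk over the flattened (tile, head) index space
        for idx in range(prog, num_tiles * num_heads, num_progs):
            out.append(body(idx % num_tiles, idx // num_tiles))
    return out

# Both launch styles cover the same work items exactly once.
assert sorted(launch_flat(4, 2)) == sorted(launch_persistent(4, 2, num_progs=3))
```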
@@ -348,6 +434,7 @@ def _attn_fwd_persist(
     warp_specialize: tl.constexpr,  #
     OUTER_LOOP: tl.constexpr,
     dtype: tl.constexpr,
+    SUBTILING: tl.constexpr,
 ):
     n_tile_num = tl.cdiv(N_CTX, BLOCK_M)
     prog_id = tl.program_id(0)
@@ -372,6 +459,8 @@ def _attn_fwd_persist(
             desc_k,
             desc_v,
             desc_o,
+            pid,
+            off_hz,
             N_CTX,
             HEAD_DIM,
             BLOCK_M,
@@ -380,6 +469,7 @@ def _attn_fwd_persist(
             STAGE,
             warp_specialize and not OUTER_LOOP,
             dtype,
+            SUBTILING,
         )
         tile_idx += num_progs

@@ -406,6 +496,7 @@ def forward(ctx, q, k, v, causal, sm_scale, baseVariant):
         M = torch.empty(
             (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
         )
+        warp_specialize = baseVariant == "ws" or baseVariant == "ws_persistent"
         # Use device_descriptor for Hopper + warpspec.
         if supports_host_descriptor() and not (is_hopper() and warp_specialize):
             # Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor
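
The flag is hoisted to this point because the descriptor-selection branch immediately below already reads `warp_specialize`; the only assignment visible in the old code sits much later (the removed line in the next hunk), so the hoist also repairs the ordering. An equivalent, slightly tighter phrasing of the two variant flags this change introduces (same semantics, just `in` instead of chained `==`):

```python
# "ws" / "ws_persistent" enable warp specialization; "persistent" /
# "ws_persistent" select the persistent kernel (see the dispatch below).
warp_specialize = baseVariant in ("ws", "ws_persistent")
persistent = baseVariant in ("persistent", "ws_persistent")
```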
@@ -473,33 +564,59 @@ def forward(ctx, q, k, v, causal, sm_scale, baseVariant):
                 1,
             )

+        def grid_debug(META):
+            return (
+                1,
+                1,
+                1,
+            )
+
         ctx.grid = grid
-        warp_specialize = baseVariant == "ws"
+        persistent = baseVariant == "persistent" or baseVariant == "ws_persistent"
         if is_blackwell() and warp_specialize:
             if HEAD_DIM_K == 128 and (
                 q.dtype == torch.float16 or q.dtype == torch.bfloat16
             ):
-                extra_kern_args["maxnreg"] = 168
+                extra_kern_args["maxnreg"] = 128
             else:
                 extra_kern_args["maxnreg"] = 80
-        _attn_fwd_persist[grid_persist](
-            sm_scale,
-            M,  #
-            q.shape[0],
-            q.shape[1],  #
-            desc_q,
-            desc_k,
-            desc_v,
-            desc_o,  #
-            N_CTX=q.shape[2],  #
-            HEAD_DIM=HEAD_DIM_K,  #
-            FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
-            STAGE=stage,  #
-            warp_specialize=warp_specialize,
-            OUTER_LOOP=True,
-            dtype=torch_dtype_to_triton(q.dtype),
-            **extra_kern_args,
-        )
+        if persistent:
+            _attn_fwd_persist[grid_persist](
+                sm_scale,
+                M,  #
+                q.shape[0],
+                q.shape[1],  #
+                desc_q,
+                desc_k,
+                desc_v,
+                desc_o,  #
+                N_CTX=q.shape[2],  #
+                HEAD_DIM=HEAD_DIM_K,  #
+                FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
+                STAGE=stage,  #
+                warp_specialize=warp_specialize,
+                OUTER_LOOP=True,
+                dtype=torch_dtype_to_triton(q.dtype),
+                **extra_kern_args,
+            )
+        else:
+            _attn_fwd[grid](
+                sm_scale,
+                M,  #
+                q.shape[0],
+                q.shape[1],  #
+                desc_q,
+                desc_k,
+                desc_v,
+                desc_o,  #
+                N_CTX=q.shape[2],  #
+                HEAD_DIM=HEAD_DIM_K,  #
+                FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
+                STAGE=stage,  #
+                warp_specialize=warp_specialize,
+                dtype=torch_dtype_to_triton(q.dtype),
+                **extra_kern_args,
+            )

         ctx.save_for_backward(q, k, v, o, M)
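
Dispatch summary: `ws` now runs the flat `_attn_fwd[grid]` launch with warp specialization, while `persistent` / `ws_persistent` take `_attn_fwd_persist[grid_persist]` with `OUTER_LOOP=True`; the added `grid_debug` (unused here) pins a single program, presumably for debugging. The `grid` and `grid_persist` bodies fall outside this diff, so the sketch below of how such grids typically differ is an assumption in the style of the Triton attention tutorial, not the file's actual definitions:

```python
import triton

# Assumed problem sizes and SM count, for illustration only.
Z, H, N_CTX, NUM_SMS = 4, 16, 2048, 132

def grid(META):
    # flat launch: one program per (query tile, batch*head) pair
    return (triton.cdiv(N_CTX, META["BLOCK_M"]), Z * H, 1)

def grid_persist(META):
    # persistent launch: at most NUM_SMS programs, each looping over tiles
    return (min(NUM_SMS, Z * H * triton.cdiv(N_CTX, META["BLOCK_M"])), 1, 1)

print(grid({"BLOCK_M": 256}))          # (8, 64, 1) -> 512 programs
print(grid_persist({"BLOCK_M": 256}))  # (132, 1, 1)
```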