Skip to content

Commit 334186f

Browse files
Mogball authored and whitneywhtsang committed
[Pipeliner] Merge warp specialization and pipeliner scheduling (#6887)
This PR refactors warp specialization to share the same scheduling as software pipelining. What this means is that the pipeliner's loop scheduler is used to set the stages and clusters of the ops, then on top of that, warp specialization will perform partition assignment and split the loop, introducing synchronization, into multiple loops that are then individually software pipelined.
1 parent e6c4956 commit 334186f

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

python/tutorials/06-fused-attention.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import pytest
1717
import torch
18-
import sys
18+
import os
1919

2020
import triton
2121
import triton.language as tl
@@ -101,18 +101,18 @@ def _host_descriptor_pre_hook(nargs):
101101
if is_hip():
102102
NUM_STAGES_OPTIONS = [1]
103103
elif supports_host_descriptor():
104-
NUM_STAGES_OPTIONS = [2, 3, 4, 5]
104+
NUM_STAGES_OPTIONS = [2, 3, 4]
105105
else:
106-
NUM_STAGES_OPTIONS = [2, 3, 4, 7]
106+
NUM_STAGES_OPTIONS = [2, 3, 4]
107107

108108
configs = [
109109
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_host_descriptor_pre_hook) \
110-
for BM in [64, 128, 256]\
110+
for BM in [64, 128]\
111111
for BN in [64, 128]\
112112
for s in NUM_STAGES_OPTIONS \
113113
for w in [4, 8]\
114114
]
115-
if "pytest" in sys.modules:
115+
if "PYTEST_VERSION" in os.environ:
116116
# Use a single config in testing for reproducibility
117117
configs = [
118118
triton.Config(dict(BLOCK_M=64, BLOCK_N=64), num_stages=4, num_warps=4, pre_hook=_host_descriptor_pre_hook),
@@ -503,6 +503,8 @@ def grid(META):
503503
return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
504504

505505
ctx.grid = grid
506+
if is_cuda() and warp_specialize:
507+
extra_kern_args["maxnreg"] = 80
506508
_attn_fwd[grid](
507509
sm_scale, M, #
508510
q.shape[0], q.shape[1], #

0 commit comments

Comments (0)