|
15 | 15 |
|
16 | 16 | import pytest |
17 | 17 | import torch |
18 | | -import sys |
| 18 | +import os |
19 | 19 |
|
20 | 20 | import triton |
21 | 21 | import triton.language as tl |
@@ -101,18 +101,18 @@ def _host_descriptor_pre_hook(nargs): |
101 | 101 | if is_hip(): |
102 | 102 | NUM_STAGES_OPTIONS = [1] |
103 | 103 | elif supports_host_descriptor(): |
104 | | - NUM_STAGES_OPTIONS = [2, 3, 4, 5] |
| 104 | + NUM_STAGES_OPTIONS = [2, 3, 4] |
105 | 105 | else: |
106 | | - NUM_STAGES_OPTIONS = [2, 3, 4, 7] |
| 106 | + NUM_STAGES_OPTIONS = [2, 3, 4] |
107 | 107 |
|
108 | 108 | configs = [ |
109 | 109 | triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_host_descriptor_pre_hook) \ |
110 | | - for BM in [64, 128, 256]\ |
| 110 | + for BM in [64, 128]\ |
111 | 111 | for BN in [64, 128]\ |
112 | 112 | for s in NUM_STAGES_OPTIONS \ |
113 | 113 | for w in [4, 8]\ |
114 | 114 | ] |
115 | | -if "pytest" in sys.modules: |
| 115 | +if "PYTEST_VERSION" in os.environ: |
116 | 116 | # Use a single config in testing for reproducibility |
117 | 117 | configs = [ |
118 | 118 | triton.Config(dict(BLOCK_M=64, BLOCK_N=64), num_stages=4, num_warps=4, pre_hook=_host_descriptor_pre_hook), |
@@ -503,6 +503,8 @@ def grid(META): |
503 | 503 | return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1) |
504 | 504 |
|
505 | 505 | ctx.grid = grid |
| 506 | + if is_cuda() and warp_specialize: |
| 507 | + extra_kern_args["maxnreg"] = 80 |
506 | 508 | _attn_fwd[grid]( |
507 | 509 | sm_scale, M, # |
508 | 510 | q.shape[0], q.shape[1], # |
|
0 commit comments