|
15 | 15 |
|
16 | 16 | import pytest
|
17 | 17 | import torch
|
18 |
| -import sys |
| 18 | +import os |
19 | 19 |
|
20 | 20 | import triton
|
21 | 21 | import triton.language as tl
|
@@ -101,18 +101,18 @@ def _host_descriptor_pre_hook(nargs):
|
101 | 101 | if is_hip():
|
102 | 102 | NUM_STAGES_OPTIONS = [1]
|
103 | 103 | elif supports_host_descriptor():
|
104 |
| - NUM_STAGES_OPTIONS = [2, 3, 4, 5] |
| 104 | + NUM_STAGES_OPTIONS = [2, 3, 4] |
105 | 105 | else:
|
106 |
| - NUM_STAGES_OPTIONS = [2, 3, 4, 7] |
| 106 | + NUM_STAGES_OPTIONS = [2, 3, 4] |
107 | 107 |
|
108 | 108 | configs = [
|
109 | 109 | triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_host_descriptor_pre_hook) \
|
110 |
| - for BM in [64, 128, 256]\ |
| 110 | + for BM in [64, 128]\ |
111 | 111 | for BN in [64, 128]\
|
112 | 112 | for s in NUM_STAGES_OPTIONS \
|
113 | 113 | for w in [4, 8]\
|
114 | 114 | ]
|
115 |
| -if "pytest" in sys.modules: |
| 115 | +if "PYTEST_VERSION" in os.environ: |
116 | 116 | # Use a single config in testing for reproducibility
|
117 | 117 | configs = [
|
118 | 118 | triton.Config(dict(BLOCK_M=64, BLOCK_N=64), num_stages=4, num_warps=4, pre_hook=_host_descriptor_pre_hook),
|
@@ -503,6 +503,8 @@ def grid(META):
|
503 | 503 | return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
|
504 | 504 |
|
505 | 505 | ctx.grid = grid
|
| 506 | + if is_cuda() and warp_specialize: |
| 507 | + extra_kern_args["maxnreg"] = 80 |
506 | 508 | _attn_fwd[grid](
|
507 | 509 | sm_scale, M, #
|
508 | 510 | q.shape[0], q.shape[1], #
|
|
0 commit comments