Commit 184fb53

[AUTOTUNER] Fix issue in autotuner which may use the wrong value as the key of config cache. (#4808)
The autotuner uses the index of the arguments in the Triton kernel signature to look up the value to be used as the key for the config cache. There is an issue if the user passes the kernel arguments as keyword args in arbitrary order: the name of the argument should be used to look up the value of the args passed by the user instead of the `key_idx`. This prevents the autotuner from using a mismatched value as the key for caching when the arguments are passed as keyword args in arbitrary order.

- [X] I am not making a trivial change, such as fixing a typo in a comment.
- [X] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [X] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [X] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
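To illustrate the behaviour described above, here is a minimal sketch (not the autotuner's actual code; the argument names and values are made up) contrasting an index-based lookup with the name-based lookup this change switches to. An index into a filtered positional list is only meaningful when that list lines up with the kernel signature, whereas looking the value up by name works for any combination and order of keyword arguments.

```python
# Hypothetical example: `arg_names` mimics a kernel signature,
# `keys` mimics triton.autotune(key=['M']).
arg_names = ["dst", "src", "stride_m", "M", "BLOCK_SIZE_N", "BLOCK_SIZE_M"]
keys = ["M"]


def key_by_index(all_args):
    # Index-based approach: positions taken from the kernel signature.
    key_idx = [arg_names.index(k) for k in keys]
    _args = [all_args[name] for name in arg_names if name in all_args]
    # _args only lines up with key_idx when every signature argument is
    # present in the call; otherwise the index points at the wrong value.
    return tuple(_args[i] for i in key_idx)


def key_by_name(all_args):
    # Name-based approach: look the value up by argument name.
    _args = {k: v for k, v in all_args.items() if k in arg_names}
    return tuple(_args[k] for k in keys if k in _args)


# Keyword arguments in arbitrary order; stride_m and BLOCK_SIZE_M are not
# passed by the caller here (e.g. supplied by the config), so the filtered
# positional list shifts relative to the signature.
call = {"BLOCK_SIZE_N": 16, "M": 512, "src": "src_ptr", "dst": "dst_ptr"}
print(key_by_index(call))  # (16,)  -- BLOCK_SIZE_N's value, not M's
print(key_by_name(call))   # (512,) -- the value of M, as intended
```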
1 parent 1b0f9ea commit 184fb53

File tree

2 files changed, +20 -20 lines changed


python/test/unit/runtime/test_autotuner.py

Lines changed: 16 additions & 13 deletions
@@ -7,22 +7,25 @@

 @pytest.mark.parametrize('use_cuda_graph', [False, True])
 def test_kwargs(use_cuda_graph: bool, device: str):
-    N = 1024
-    src = torch.randn(N, device=device)
-    dst = torch.empty(N, device=device)
+    M, N = 1024, 16
+    src = torch.randn(M * N, device=device)
+    dst = torch.empty(M * N, device=device)

-    configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]
+    configs = [triton.Config(kwargs={'BLOCK_SIZE_M': 32}), triton.Config(kwargs={'BLOCK_SIZE_M': 128})]

-    @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph)
+    @triton.autotune(configs=configs, key=['M'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph)
     @triton.jit
-    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
-        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        x = tl.load(src + offsets, mask=offsets < N)
-        tl.store(dst + offsets, x, mask=offsets < N)
-
-    grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
-    _kernel[grid](dst, src, N)
-    _kernel[grid](dst=dst, src=src, N=N)
+    def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr):
+        offsets_m = tl.program_id(0) * stride_m + tl.arange(0, BLOCK_SIZE_M)
+        offsets_n = tl.arange(0, BLOCK_SIZE_N)
+        x = tl.load(src + offsets_m[:, None] * BLOCK_SIZE_N + offsets_n[None, :])
+        tl.store(dst + offsets_m[:, None] * BLOCK_SIZE_N + offsets_n[None, :], x)
+
+    grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE_M']), )
+    _kernel[grid](dst, src, N, M, N)
+    # the key word args could be in arbitrary order.
+    _kernel[grid](dst=dst, src=src, M=M // 2, stride_m=N, BLOCK_SIZE_N=N)
+    assert len(_kernel.cache) == 2


 def test_restore(device):

python/triton/runtime/autotuner.py

Lines changed: 4 additions & 7 deletions
@@ -38,7 +38,7 @@ def __init__(
             self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)]
         else:
             self.configs = configs
-        self.key_idx = [arg_names.index(k) for k in key]
+        self.keys = key
         self.cache = {}
         self.arg_names = arg_names

@@ -136,12 +136,9 @@ def run(self, *args, **kwargs):
         used_cached_result = True
         if len(self.configs) > 1:
             all_args = {**self.nargs, **kwargs}
-            _args = []
-            for name in self.arg_names:
-                if name in all_args:
-                    _args.append(all_args[name])
-            key = [_args[i] for i in self.key_idx]
-            for arg in _args:
+            _args = {k: v for (k, v) in all_args.items() if k in self.arg_names}
+            key = [_args[key] for key in self.keys if key in _args]
+            for _, arg in _args.items():
                 if hasattr(arg, "dtype"):
                     key.append(str(arg.dtype))
             key = tuple(key)
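For reference, a small sketch (assumed argument names and values, mirroring the lines added above) of how the new name-based key construction behaves: the same logical call produces the same config-cache key whether its arguments arrive positionally or as keyword arguments in a shuffled order.

```python
import torch

# Hypothetical signature and autotune key, mirroring the diff above.
arg_names = ["dst", "src", "stride_m", "M", "BLOCK_SIZE_N", "BLOCK_SIZE_M"]
keys = ["M"]


def cache_key(all_args):
    # Same shape as the new code in Autotuner.run: filter to known argument
    # names, pick key values by name, then append tensor dtypes.
    _args = {k: v for (k, v) in all_args.items() if k in arg_names}
    key = [_args[k] for k in keys if k in _args]
    for _, arg in _args.items():
        if hasattr(arg, "dtype"):
            key.append(str(arg.dtype))
    return tuple(key)


dst, src = torch.empty(4), torch.empty(4)
# Positional-style call (signature order) vs. keyword args in arbitrary order.
positional = dict(zip(arg_names, [dst, src, 16, 1024, 16]))
keyword = {"BLOCK_SIZE_N": 16, "src": src, "M": 1024, "dst": dst, "stride_m": 16}
assert cache_key(positional) == cache_key(keyword)  # same cache entry either way
```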
