Skip to content

Commit 0cea768

Browse files
authored
[RUNTIME] Pass full kwargs to Autotuner hooks instead of positional args (#5083)
Currently, when there are hooks accessing args in the `@triton.autotuner` (e.g., `pre_hook` / `post_hook` or those created from the `restore_value` or `reset_to_zero`) and the kernel arguments are passed as kwargs, this breaks, because the hooks take positional `args` only. This PR changes the first parameter of the `pre_hook` and `post_hook` of the `Autotuner` from (partial) `args` to (full) `kwargs`. As a result, we now have access to all arguments, positional or keyword, passed to the kernel call. The call sites and docs are updated accordingly. N.B.: This change is BC-breaking! In the signatures of `pre_hook` and `post_hook`, the first parameter type has changed from list (of positional args) to dict (of all kwargs). Submitting, as agreed with @Jokeren and @peterbell10 below. Fixes #5082. See the failing example, error, and code pointers there.
1 parent 57643b3 commit 0cea768

File tree

2 files changed

+31
-25
lines changed

2 files changed

+31
-25
lines changed

python/test/unit/runtime/test_autotuner.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLO
3232
assert len(_kernel.cache) == 2
3333

3434

35-
def test_restore(device):
35+
@pytest.mark.parametrize('pass_kwargs_to_kernel', [False, True])
36+
def test_restore(pass_kwargs_to_kernel, device):
3637
N = 1024
3738
src = torch.zeros(N, device=device)
3839

@@ -46,7 +47,10 @@ def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
4647
tl.store(src + offsets, x, mask=offsets < N)
4748

4849
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
49-
_kernel[grid](src, N)
50+
if pass_kwargs_to_kernel:
51+
_kernel[grid](src=src, N=N)
52+
else:
53+
_kernel[grid](src, N)
5054
triton.testing.assert_close(src, torch.ones_like(src))
5155

5256

python/triton/runtime/autotuner.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -44,36 +44,36 @@ def __init__(
4444
self.arg_names = arg_names
4545

4646
# Reset to zero or restore values
47-
self.reset_idx = []
47+
self.reset_to_zero = []
4848
if reset_to_zero is not None:
49-
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
50-
self.restore_idx = []
49+
self.reset_to_zero = list(reset_to_zero)
50+
self.restore_value = []
5151
if restore_value is not None:
52-
self.restore_idx = [arg_names.index(k) for k in restore_value]
52+
self.restore_value = list(restore_value)
5353

5454
# Hook to reset or restore for required tensors
55-
self.pre_hook = lambda args, reset_only=False: 0
56-
self.post_hook = lambda args, exception: 0
55+
self.pre_hook = lambda kwargs, reset_only=False: 0
56+
self.post_hook = lambda kwargs, exception: 0
5757
if pre_hook:
5858
self.pre_hook = pre_hook
59-
elif (len(self.reset_idx) > 0 or len(self.restore_idx) > 0):
59+
elif (len(self.reset_to_zero) > 0 or len(self.restore_value) > 0):
6060

61-
def _pre_hook(args, reset_only=False):
62-
for i in self.reset_idx:
63-
args[i].zero_()
61+
def _pre_hook(kwargs, reset_only=False):
62+
for name in self.reset_to_zero:
63+
kwargs[name].zero_()
6464
if not reset_only:
65-
self.restore_copies = [args[i].clone() for i in self.restore_idx]
65+
self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value}
6666

6767
self.pre_hook = _pre_hook
6868

6969
if post_hook:
7070
self.post_hook = post_hook
71-
elif len(self.restore_idx) > 0:
71+
elif len(self.restore_value) > 0:
7272

73-
def _post_hook(args, exception):
74-
for i, j in enumerate(self.restore_idx):
75-
args[j].copy_(self.restore_copies[i])
76-
self.restore_copies = []
73+
def _post_hook(kwargs, exception):
74+
for name in self.restore_value:
75+
kwargs[name].copy_(self.restore_copies[name])
76+
self.restore_copies = {}
7777

7878
self.post_hook = _post_hook
7979

@@ -140,20 +140,20 @@ def _bench(self, *args, config, **meta):
140140
def kernel_call():
141141
if config.pre_hook:
142142
config.pre_hook(full_nargs)
143-
self.pre_hook(args)
143+
self.pre_hook(full_nargs)
144144
try:
145145
self.fn.run(
146146
*args,
147147
**current,
148148
)
149149
except Exception as e:
150150
try:
151-
self.post_hook(args, exception=e)
151+
self.post_hook(full_nargs, exception=e)
152152
finally:
153153
# Throw exception raised by `self.fn.run`
154154
raise
155155

156-
self.post_hook(args, exception=None)
156+
self.post_hook(full_nargs, exception=None)
157157

158158
try:
159159
return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
@@ -180,7 +180,8 @@ def run(self, *args, **kwargs):
180180
bench_end = time.time()
181181
self.bench_time = bench_end - bench_start
182182
self.cache[key] = builtins.min(timings, key=timings.get)
183-
self.pre_hook(args, reset_only=True)
183+
full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()}
184+
self.pre_hook(full_nargs, reset_only=True)
184185
self.configs_timings = timings
185186
config = self.cache[key]
186187
else:
@@ -190,7 +191,8 @@ def run(self, *args, **kwargs):
190191
print(f"Triton autotuning for function {self.base_fn.__name__} finished after "
191192
f"{self.bench_time:.2f}s; best config selected: {self.best_config};")
192193
if config.pre_hook is not None:
193-
config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()})
194+
full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
195+
config.pre_hook(full_nargs)
194196
ret = self.fn.run(
195197
*args,
196198
**kwargs,
@@ -326,12 +328,12 @@ def kernel(x_ptr, x_size, **META):
326328
:type restore_value: list[str]
327329
:param pre_hook: a function that will be called before the kernel is called.
328330
This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'.
329-
'args': a list of arguments passed to the kernel.
331+
'kwargs': a dict of all arguments passed to the kernel.
330332
'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook.
331333
:type pre_hook: lambda args, reset_only
332334
:param post_hook: a function that will be called after the kernel is called.
333335
This overrides the default post_hook used for 'restore_value'.
334-
'args': a list of arguments passed to the kernel.
336+
'kwargs': a dict of all arguments passed to the kernel.
335337
'exception': the exception raised by the kernel in case of a compilation or runtime error.
336338
:type post_hook: lambda args, exception
337339
:param warmup: warmup time (in ms) to pass to benchmarking (deprecated).

0 commit comments

Comments
 (0)