Skip to content

Commit e00903a

Browse files
authored
[RUNTIME] Add flags for detecting user-defined Autotuner hooks (#5092)
When tracing user-defined Triton kernels in PyTorch 2, TorchDynamo intercepts already instantiated `Autotuner` or `JITFunction` objects ([here](https://github.com/pytorch/pytorch/blob/ab42967238483b457af7228b9185b5816b83ed08/torch/_higher_order_ops/triton_kernel_wrap.py#L1034)). This is because the decorator calls are normally made outside the compiled part of the model code, with the latter covering only kernel calls. At this point, we don't support user-defined `Autotuner` `pre_hook` and `post_hook` in PT2 compilation. However, due to the fact that the `pre_hook` and `post_hook` attributes are always set to a lambda ([here](https://github.com/triton-lang/triton/blob/main/python/triton/runtime/autotuner.py#L55-L56)), we can't detect that the hooks are set by the user by introspecting the `Autotuner` object. This leads to silent incorrectness, as the user-defined hooks are simply ignored. This PR adds explicit boolean flags indicating whether the `pre_hook` and / or `post_hook` are set by the user. With those in place, we can detect this and raise a clear error around [here](https://github.com/pytorch/pytorch/blob/ab42967238483b457af7228b9185b5816b83ed08/torch/_higher_order_ops/triton_kernel_wrap.py#L1072-L1074) in TorchDynamo, avoiding silent incorrectness. - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [x] This PR does not need a test because `just adding flags`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 4ea942b commit e00903a

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

python/triton/runtime/autotuner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,11 @@ def __init__(
5454
# Hook to reset or restore for required tensors
5555
self.pre_hook = lambda kwargs, reset_only=False: 0
5656
self.post_hook = lambda kwargs, exception: 0
57+
self.user_defined_pre_hook = False
58+
self.user_defined_post_hook = False
5759
if pre_hook:
5860
self.pre_hook = pre_hook
61+
self.user_defined_pre_hook = True
5962
elif (len(self.reset_to_zero) > 0 or len(self.restore_value) > 0):
6063

6164
def _pre_hook(kwargs, reset_only=False):
@@ -68,6 +71,7 @@ def _pre_hook(kwargs, reset_only=False):
6871

6972
if post_hook:
7073
self.post_hook = post_hook
74+
self.user_defined_post_hook = True
7175
elif len(self.restore_value) > 0:
7276

7377
def _post_hook(kwargs, exception):

0 commit comments

Comments
 (0)