
Commit 5132916

[frontend][typing] Adding typing & small fix for autotuner.py (#5459)
Hello everyone!

# Context

I've been working on improving the compatibility of torch.compile with user-defined Triton kernels and recently ran into a type-checking (and potential correctness) issue [here](https://github.com/pytorch/pytorch/pull/142207/files#diff-f5fa7d0e418e91c63fa56d577a92a294c87e19318a3c8b3736ac4254eaa51db9R1054). This doesn't type the entire file, but it fixes several typing issues in autotuner.py. There is a plan to revisit this file and other user-facing portions of the API, but since this change catches a potential bug, I thought it would be worth filing sooner rather than later.

# Testing

I ran the pre-commit hooks and the Python unit tests (`python3 -m python/test/unit`), and ran mypy with the following command:

```bash
mypy python/triton/runtime/autotuner.py | grep autotune
```
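To make the "potential bug" above concrete, here is a minimal, self-contained sketch of the `top_k` normalization that `Autotuner.prune_configs` performs. The function name `prune_by_top_k` and the string configs are illustrative stand-ins, not Triton's API (real entries are `triton.Config` objects); only the branching mirrors the diff below:

```python
from typing import List, Union


def prune_by_top_k(configs: List[str], top_k: Union[int, float]) -> List[str]:
    # Mirrors the normalization in Autotuner.prune_configs.
    if isinstance(top_k, float) and top_k <= 1.0:
        # A float <= 1.0 is treated as a fraction of the config list.
        top_k = int(len(configs) * top_k)
    elif not isinstance(top_k, int):
        # The guard this commit adds: without it, the slice below fails
        # later with an opaque "slice indices must be integers" TypeError.
        raise TypeError("top_k must be either 1) a float <= 1.0 or 2) an int")
    return configs[:top_k]


configs = ["cfg_a", "cfg_b", "cfg_c", "cfg_d"]
print(prune_by_top_k(configs, 0.5))  # fraction of 4 configs -> ['cfg_a', 'cfg_b']
print(prune_by_top_k(configs, 3))    # plain int -> ['cfg_a', 'cfg_b', 'cfg_c']
try:
    prune_by_top_k(configs, 1.5)     # float > 1.0: now rejected up front
except TypeError as exc:
    print(exc)
```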
1 parent 9751269 · commit 5132916

1 file changed: 8 additions, 4 deletions

python/triton/runtime/autotuner.py

```diff
@@ -4,7 +4,7 @@
 import os
 import time
 import inspect
-from typing import Dict
+from typing import Dict, Tuple, List, Optional
 
 from .jit import KernelInterface
 from .errors import OutOfResources, PTXASError
@@ -23,7 +23,7 @@ def __init__(
         restore_value,
         pre_hook=None,
         post_hook=None,
-        prune_configs_by: Dict = None,
+        prune_configs_by: Optional[Dict] = None,
         warmup=None,
         rep=None,
         use_cuda_graph=False,
@@ -40,7 +40,7 @@ def __init__(
         else:
             self.configs = configs
         self.keys = key
-        self.cache = {}
+        self.cache: Dict[Tuple, Config] = {}
         self.arg_names = arg_names
 
         # Reset to zero or restore values
@@ -211,14 +211,18 @@ def run(self, *args, **kwargs):
         self.nargs = None
         return ret
 
-    def prune_configs(self, kwargs):
+    def prune_configs(self, kwargs: Dict) -> List[Config]:
         pruned_configs = self.configs
         if self.early_config_prune:
             pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
         if self.perf_model:
             top_k = self.configs_top_k
             if isinstance(top_k, float) and top_k <= 1.0:
                 top_k = int(len(self.configs) * top_k)
+            elif not isinstance(top_k, int):
+                # Slice index must be an integer
+                raise TypeError(f"Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int")
+
             if len(pruned_configs) > top_k:
                 est_timing = {
                     config: self.perf_model(
```
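For reference, a sketch of how a kernel author reaches the changed code path through `prune_configs_by`. The `estimate_runtime` perf model and `scale_kernel` are hypothetical examples, though the `perf_model`/`top_k` dict keys match the ones `Autotuner.__init__` reads:

```python
import triton
import triton.language as tl


def estimate_runtime(**kwargs):
    # Toy heuristic: assume larger blocks run faster. Lower scores win,
    # since the autotuner keeps the configs with the smallest estimates.
    return 1.0 / kwargs["BLOCK_SIZE"]


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}, num_warps=2),
        triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 256}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=8),
    ],
    key=["n_elements"],
    # top_k=0.5 keeps the best half of the configs; after this commit, a
    # value that is neither an int nor a float <= 1.0 raises a descriptive
    # TypeError instead of failing later at the slice.
    prune_configs_by={"perf_model": estimate_runtime, "top_k": 0.5},
)
@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)
```

Pruning runs on the first launch for each new value of the `key` arguments, so that launch is where the new `TypeError` would surface for a bad `top_k`.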
