Skip to content

Commit 76045c8

Browse files
authored
[python][typehint] Simplify find_paths_if and add typehints to _utils (#6497)
A while ago I brought up the idea of adding typehints to triton. I started following through on that ambition in triton-lang/triton#6467, and want to continue the endeavor here. I plan to keep taking the 'leafiest' modules and adding typehinting to them. In order to validate types locally I'm using mypy with an lsp connection to neovim. One there's a good amount of typehinting I can go through and add `ignore` statements everywhere there's an error so that we can get typechecking in our pre-commit. For now I'll address type errors ad-hoc as I see them.
1 parent aacd5c0 commit 76045c8

File tree

1 file changed

+23
-21
lines changed

1 file changed

+23
-21
lines changed

python/triton/_utils.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,37 @@
1+
from __future__ import annotations
2+
13
from functools import reduce
4+
from typing import Any, Callable, TYPE_CHECKING, Union
5+
6+
if TYPE_CHECKING:
7+
from .language import core
8+
IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type]
9+
ObjPath = tuple[int, ...]
210

311

4-
def get_iterable_path(iterable, path):
5-
return reduce(lambda a, idx: a[idx], path, iterable)
12+
def get_iterable_path(iterable: IterableType, path: ObjPath) -> Any:
13+
return reduce(lambda a, idx: a[idx], path, iterable) # type: ignore[index]
614

715

8-
def set_iterable_path(iterable, path, val):
16+
def set_iterable_path(iterable: IterableType, path: tuple[int, ...], val: Any):
17+
assert len(path) != 0
918
prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1])
10-
prev[path[-1]] = val
19+
prev[path[-1]] = val # type: ignore[index]
1120

1221

13-
def find_paths_if(iterable, pred):
22+
def find_paths_if(iterable: Union[IterableType, Any], pred: Callable[[ObjPath, Any], bool]) -> list[ObjPath]:
1423
from .language import core
15-
is_iterable = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
16-
ret = dict()
24+
is_iterable: Callable[[Any], bool] = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
25+
# We need to use dict so that ordering is maintained, while set doesn't guarantee order
26+
ret: dict[ObjPath, None] = {}
1727

18-
def _impl(current, path):
19-
path = (path[0], ) if len(path) == 1 else tuple(path)
28+
def _impl(path: tuple[int, ...], current: Any):
2029
if is_iterable(current):
2130
for idx, item in enumerate(current):
22-
_impl(item, path + (idx, ))
31+
_impl((*path, idx), item)
2332
elif pred(path, current):
24-
if len(path) == 1:
25-
ret[(path[0], )] = None
26-
else:
27-
ret[tuple(path)] = None
28-
29-
if is_iterable(iterable):
30-
_impl(iterable, [])
31-
elif pred(list(), iterable):
32-
ret = {tuple(): None}
33-
else:
34-
ret = dict()
33+
ret[path] = None
34+
35+
_impl((), iterable)
36+
3537
return list(ret.keys())

0 commit comments

Comments
 (0)