Skip to content

Commit 8557148

Browse files
authored
[typehint][knobs] Specify values in metadata_group as str (#6774)
I didn't catch this on triton-lang/triton#6364 -- `Any` is fine, but we know the values are `str` from [`get_group`](https://github.com/triton-lang/triton/blob/b73d59597370a5ea7bd084d0f047516b029d9f4e/python/triton/runtime/cache.py#L26)
1 parent b73d595 commit 8557148

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

python/test/unit/runtime/test_compilation_listener.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def cumsum_kernel(ptr):
2020
def test_compile_stats(device: str, fresh_knobs_except_libraries: Any, fresh_triton_cache: str) -> None:
2121
captured: Union[tuple[Union[ASTSource, IRSource], dict[str, Any], dict[str, Any], CompileTimes, bool], None] = None
2222

23-
def compile_listener(src: Union[ASTSource, IRSource], metadata: dict[str, Any], metadata_group: dict[str, Any],
23+
def compile_listener(src: Union[ASTSource, IRSource], metadata: dict[str, str], metadata_group: dict[str, Any],
2424
times: CompileTimes, cache_hit: bool) -> None:
2525
nonlocal captured
2626
assert captured is None

python/triton/knobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def total(self) -> int:
253253

254254
class CompilationListener(Protocol):
255255

256-
def __call__(self, *, src: Union[ASTSource, IRSource], metadata: dict[str, Any], metadata_group: dict[str, Any],
256+
def __call__(self, *, src: Union[ASTSource, IRSource], metadata: dict[str, Any], metadata_group: dict[str, str],
257257
times: CompileTimes, cache_hit: bool) -> None:
258258
...
259259

0 commit comments

Comments
 (0)