Skip to content

Commit 06f1b52

Browse files
authored
Add metaclass [] syntax for cache classes (#415)
1 parent e42c765 commit 06f1b52

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

helion/autotuner/base_cache.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import logging
88
import os
99
from typing import TYPE_CHECKING
10+
from typing import Any
11+
from typing import Callable
1012
from typing import Hashable
1113

1214
from torch._inductor.codecache import build_code_hash
@@ -16,12 +18,36 @@
1618
from .base_search import BaseAutotuner
1719

1820
if TYPE_CHECKING:
21+
from collections.abc import Sequence
22+
1923
from ..runtime.config import Config
24+
from ..runtime.kernel import BoundKernel
2025
from .base_search import BaseSearch
2126

2227
log: logging.Logger = logging.getLogger(__name__)
2328

2429

30+
class AutotuneCacheMeta(abc.ABCMeta):
31+
"""Metaclass that enables the Cache[Search] syntax for autotuner cache classes."""
32+
33+
def __getitem__(
34+
cls, search_cls: type[BaseSearch]
35+
) -> Callable[[BoundKernel, Sequence[Any]], BaseAutotuner]:
36+
"""Enable Cache[Search] syntax to create a factory function.
37+
38+
Args:
39+
search_cls: The search class to use with this cache
40+
41+
Returns:
42+
A factory function that creates cache instances with the specified search
43+
"""
44+
45+
def factory(kernel: BoundKernel, args: Sequence[Any]) -> BaseAutotuner:
46+
return cls(search_cls(kernel, args)) # type: ignore[misc]
47+
48+
return factory
49+
50+
2551
@functools.cache
2652
def helion_key() -> str:
2753
here = os.path.abspath(__file__)
@@ -107,7 +133,7 @@ class StrictAutotuneCacheKey(LooseAutotuneCacheKey):
107133
triton_key: str = dataclasses.field(default_factory=triton_key_wrapper)
108134

109135

110-
class AutotuneCacheBase(BaseAutotuner, abc.ABC):
136+
class AutotuneCacheBase(BaseAutotuner, abc.ABC, metaclass=AutotuneCacheMeta):
111137
"""
112138
Abstract base class that all autotune caches need to implement.
113139
Any user defined cache will need to extend this class, and

test/test_cache.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ def autotune(self):
2121

2222
class TestCache(RefEagerTestDisabled, TestCase):
2323
def test_basic(self):
24-
@helion.kernel(
25-
autotuner_fn=lambda k, a: StrictLocalAutotuneCache(BasicSearch(k, a))
26-
)
24+
@helion.kernel(autotuner_fn=StrictLocalAutotuneCache[BasicSearch])
2725
def add(x, y):
2826
x, y = torch.broadcast_tensors(x, y)
2927
out = torch.empty_like(x)

0 commit comments

Comments
 (0)