|
7 | 7 | import logging
|
8 | 8 | import os
|
9 | 9 | from typing import TYPE_CHECKING
|
| 10 | +from typing import Any |
| 11 | +from typing import Callable |
10 | 12 | from typing import Hashable
|
11 | 13 |
|
12 | 14 | from torch._inductor.codecache import build_code_hash
|
|
16 | 18 | from .base_search import BaseAutotuner
|
17 | 19 |
|
18 | 20 | if TYPE_CHECKING:
|
| 21 | + from collections.abc import Sequence |
| 22 | + |
19 | 23 | from ..runtime.config import Config
|
| 24 | + from ..runtime.kernel import BoundKernel |
20 | 25 | from .base_search import BaseSearch
|
21 | 26 |
|
22 | 27 | log: logging.Logger = logging.getLogger(__name__)
|
23 | 28 |
|
24 | 29 |
|
| 30 | +class AutotuneCacheMeta(abc.ABCMeta): |
| 31 | + """Metaclass that enables the Cache[Search] syntax for autotuner cache classes.""" |
| 32 | + |
| 33 | + def __getitem__( |
| 34 | + cls, search_cls: type[BaseSearch] |
| 35 | + ) -> Callable[[BoundKernel, Sequence[Any]], BaseAutotuner]: |
| 36 | + """Enable Cache[Search] syntax to create a factory function. |
| 37 | +
|
| 38 | + Args: |
| 39 | + search_cls: The search class to use with this cache |
| 40 | +
|
| 41 | + Returns: |
| 42 | + A factory function that creates cache instances with the specified search |
| 43 | + """ |
| 44 | + |
| 45 | + def factory(kernel: BoundKernel, args: Sequence[Any]) -> BaseAutotuner: |
| 46 | + return cls(search_cls(kernel, args)) # type: ignore[misc] |
| 47 | + |
| 48 | + return factory |
| 49 | + |
| 50 | + |
25 | 51 | @functools.cache
|
26 | 52 | def helion_key() -> str:
|
27 | 53 | here = os.path.abspath(__file__)
|
@@ -107,7 +133,7 @@ class StrictAutotuneCacheKey(LooseAutotuneCacheKey):
|
107 | 133 | triton_key: str = dataclasses.field(default_factory=triton_key_wrapper)
|
108 | 134 |
|
109 | 135 |
|
110 |
| -class AutotuneCacheBase(BaseAutotuner, abc.ABC): |
| 136 | +class AutotuneCacheBase(BaseAutotuner, abc.ABC, metaclass=AutotuneCacheMeta): |
111 | 137 | """
|
112 | 138 | Abstract base class that all autotune caches need to implement.
|
113 | 139 | Any user defined cache will need to extend this class, and
|
|
0 commit comments