Skip to content

Commit 9064eda

Browse files
authored
[RFC] Implement basic on disk caching (#336)
1 parent 869b590 commit 9064eda

File tree

8 files changed

+374
-9
lines changed

8 files changed

+374
-9
lines changed

helion/_testing.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import collections
4+
import contextlib
45
import importlib
56
import inspect
67
import operator
@@ -15,6 +16,7 @@
1516
import torch
1617
from triton.testing import do_bench
1718

19+
from ._utils import counters
1820
from .runtime.config import Config
1921
from helion._compat import get_tensor_descriptor_fn_name
2022

@@ -291,6 +293,20 @@ def tearDownClass(cls) -> None:
291293
super().tearDownClass()
292294
del cls._expected_journal
293295

296+
def setUp(self) -> None:
297+
super().setUp()
298+
self._test_stack = contextlib.ExitStack()
299+
300+
from torch._inductor.utils import fresh_cache
301+
302+
self._test_stack.enter_context(fresh_cache())
303+
304+
counters.clear()
305+
306+
def tearDown(self) -> None:
307+
super().tearDown()
308+
self._test_stack.close()
309+
294310
def assertExpectedJournal(self, value: str) -> None:
295311
"""
296312
Assert that the given value matches the expected output stored in <testfile>.expected.

helion/_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from __future__ import annotations
2+
3+
import collections
4+
5+
counters: collections.defaultdict[str, collections.Counter[str]] = (
6+
collections.defaultdict(collections.Counter)
7+
)

helion/autotuner/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,6 @@
99
DifferentialEvolutionSearch as DifferentialEvolutionSearch,
1010
)
1111
from .finite_search import FiniteSearch as FiniteSearch
12+
from .local_cache import LocalAutotuneCache as LocalAutotuneCache
13+
from .local_cache import StrictLocalAutotuneCache as StrictLocalAutotuneCache
1214
from .random_search import RandomSearch as RandomSearch

helion/autotuner/base_cache.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
from __future__ import annotations
2+
3+
import abc
4+
import dataclasses
5+
import functools
6+
import hashlib
7+
import logging
8+
import os
9+
from typing import TYPE_CHECKING
10+
from typing import Hashable
11+
from typing import Sequence
12+
13+
from torch._inductor.codecache import build_code_hash
14+
from torch._inductor.codecache import torch_key
15+
from torch._inductor.runtime.triton_compat import triton_key
16+
17+
from .._utils import counters
18+
19+
if TYPE_CHECKING:
20+
from ..runtime.config import Config
21+
from ..runtime.kernel import BoundKernel
22+
from .base_search import BaseSearch
23+
24+
log: logging.Logger = logging.getLogger(__name__)
25+
26+
27+
@functools.cache
28+
def helion_key() -> str:
29+
here = os.path.abspath(__file__)
30+
helion_path = os.path.dirname(os.path.dirname(here))
31+
32+
combined_hash = hashlib.sha256()
33+
build_code_hash([helion_path], "", combined_hash)
34+
return combined_hash.hexdigest()
35+
36+
37+
@functools.cache
38+
def torch_key_wrapper() -> str:
39+
return torch_key().hex()
40+
41+
42+
@functools.cache
43+
def triton_key_wrapper() -> str:
44+
return triton_key()
45+
46+
47+
class CacheKeyBase:
48+
"""
49+
Base class to provide utility functions to all cache key dataclasses
50+
"""
51+
52+
def stable_hash(self) -> str:
53+
return hashlib.sha256(repr(self).encode("utf-8")).hexdigest()
54+
55+
56+
@dataclasses.dataclass(frozen=True)
57+
class BoundKernelInMemoryCacheKey(CacheKeyBase):
58+
"""
59+
Default in memory cache key.
60+
61+
This key includes:
62+
63+
specialization_key: Information about all kernel inputs.
64+
For tensors this means their device, shape, size etc.
65+
extra_results: Information regarding `hl.specialize` decisions
66+
"""
67+
68+
specialization_key: tuple[Hashable, ...]
69+
extra_results: tuple[Hashable, ...]
70+
71+
72+
@dataclasses.dataclass(frozen=True)
73+
class LooseAutotuneCacheKey(BoundKernelInMemoryCacheKey):
74+
"""
75+
Autotune Cache key to use for most use cases.
76+
77+
This key includes (in addition to BoundKernelInMemoryCacheKey):
78+
79+
kernel_source_hash: Hash of source code of input Helion kernel
80+
hardware: Hardware of the input device
81+
runtime_name: Version of the cuda/rocm arch
82+
"""
83+
84+
kernel_source_hash: str
85+
hardware: str
86+
runtime_name: str
87+
88+
def stable_hash(self) -> str:
89+
return hashlib.sha256(repr(self).encode("utf-8")).hexdigest()
90+
91+
92+
@dataclasses.dataclass(frozen=True)
93+
class StrictAutotuneCacheKey(LooseAutotuneCacheKey):
94+
"""
95+
Autotune Cache key to use for utmost strictness in terms of re-autotuning
96+
when library source code changes.
97+
98+
This key includes (in addition to StrictAutotuneCacheKey):
99+
100+
helion_key: Hash of source code of Helion
101+
torch_key: Hash of source code of PyTorch
102+
triton_key: Hash of source code of Triton
103+
"""
104+
105+
helion_key: str = dataclasses.field(default_factory=helion_key)
106+
torch_key: str = dataclasses.field(default_factory=torch_key_wrapper)
107+
triton_key: str = dataclasses.field(default_factory=triton_key_wrapper)
108+
109+
110+
class AutotuneCacheBase(abc.ABC):
111+
"""
112+
Abstract base class that all autotune caches need to implement.
113+
Any user defined cache will need to extend this class, and
114+
provide implementations for get and put methods.
115+
"""
116+
117+
def __init__(
118+
self, kernel: BoundKernel, args: Sequence[object], autotuner: BaseSearch
119+
) -> None:
120+
self.autotuner = autotuner
121+
self.kernel = kernel
122+
self.args = args
123+
124+
@abc.abstractmethod
125+
def get(self) -> Config | None:
126+
raise NotImplementedError
127+
128+
@abc.abstractmethod
129+
def put(self, config: Config) -> None:
130+
raise NotImplementedError
131+
132+
def autotune(self) -> Config:
133+
if (config := self.get()) is not None:
134+
counters["autotune"]["cache_hit"] += 1
135+
log.debug("cache hit: %s", str(config))
136+
return config
137+
138+
counters["autotune"]["cache_miss"] += 1
139+
log.debug("cache miss")
140+
141+
config = self.autotuner.autotune()
142+
143+
self.put(config)
144+
counters["autotune"]["cache_put"] += 1
145+
log.debug("cache put: %s", str(config))
146+
147+
return config

helion/autotuner/local_cache.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
from __future__ import annotations
2+
3+
import hashlib
4+
import inspect
5+
import logging
6+
import os
7+
from pathlib import Path
8+
import textwrap
9+
from typing import TYPE_CHECKING
10+
from typing import Sequence
11+
12+
import torch
13+
from torch._inductor.runtime.cache_dir_utils import (
14+
cache_dir, # pyright: ignore[reportPrivateImportUsage]
15+
)
16+
17+
from ..runtime.config import Config
18+
from .base_cache import AutotuneCacheBase
19+
from .base_cache import LooseAutotuneCacheKey
20+
from .base_cache import StrictAutotuneCacheKey
21+
22+
if TYPE_CHECKING:
23+
from ..runtime.kernel import BoundKernel
24+
from .base_search import BaseSearch
25+
26+
log: logging.Logger = logging.getLogger(__name__)
27+
28+
29+
class LocalAutotuneCache(AutotuneCacheBase):
30+
"""
31+
This class implements the local autotune cache, storing the
32+
best config artifact on the local file system either by default
33+
on torch's cache directory, or at a user specified HELION_CACHE_DIR
34+
directory.
35+
It uses the LooseAutotuneCacheKey implementation for the cache key
36+
which takes into account device and source code properties, but does
37+
not account for library level code changes such as Triton, Helion or
38+
PyTorch. Use StrictLocalAutotuneCache to consider these properties.
39+
"""
40+
41+
def __init__(
42+
self, kernel: BoundKernel, args: Sequence[object], autotuner: BaseSearch
43+
) -> None:
44+
super().__init__(kernel, args, autotuner)
45+
self.key = self._generate_key()
46+
47+
def _generate_key(self) -> LooseAutotuneCacheKey:
48+
in_memory_cache_key = self.kernel.kernel._create_bound_kernel_cache_key(
49+
self.kernel,
50+
tuple(self.args),
51+
self.kernel.kernel.specialization_key(self.args),
52+
)
53+
kernel_source = textwrap.dedent(inspect.getsource(self.kernel.kernel.fn))
54+
kernel_source_hash = hashlib.sha256(kernel_source.encode("utf-8")).hexdigest()
55+
56+
hardware = None
57+
runtime_name = None
58+
59+
for arg in self.args:
60+
if isinstance(arg, torch.Tensor):
61+
device_properties = torch.cuda.get_device_properties(arg.device)
62+
if torch.version.cuda is not None: # pyright: ignore[reportAttributeAccessIssue]
63+
hardware = device_properties.name
64+
runtime_name = torch.version.cuda # pyright: ignore[reportAttributeAccessIssue]
65+
else:
66+
hardware = device_properties.gcnArchName
67+
runtime_name = torch.version.hip # pyright: ignore[reportAttributeAccessIssue]
68+
69+
assert hardware is not None and runtime_name is not None
70+
return LooseAutotuneCacheKey(
71+
specialization_key=in_memory_cache_key.specialization_key,
72+
extra_results=in_memory_cache_key.extra_results,
73+
kernel_source_hash=kernel_source_hash,
74+
hardware=hardware,
75+
runtime_name=runtime_name,
76+
)
77+
78+
def _get_local_cache_path(self) -> Path:
79+
if (user_path := os.environ.get("HELION_CACHE_DIR", None)) is not None:
80+
cache_path = Path(user_path)
81+
else:
82+
cache_path = Path(cache_dir()) / "helion"
83+
84+
return cache_path / f"{self.key.stable_hash()}.best_config"
85+
86+
def get(self) -> Config | None:
87+
path = self._get_local_cache_path()
88+
try:
89+
return Config.load(path)
90+
except Exception:
91+
return None
92+
93+
def put(self, config: Config) -> None:
94+
path = self._get_local_cache_path()
95+
config.save(path)
96+
97+
98+
class StrictLocalAutotuneCache(LocalAutotuneCache):
99+
"""
100+
Stricter implementation of the local autotune cache, which takes into
101+
account library level code changes such as Triton, Helion or PyTorch.
102+
"""
103+
104+
def _generate_key(self) -> StrictAutotuneCacheKey:
105+
loose_key = super()._generate_key()
106+
return StrictAutotuneCacheKey(**vars(loose_key))

helion/runtime/config.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
from collections.abc import Iterator
44
from collections.abc import Mapping
55
import json
6+
import os
67
from pathlib import Path
78
from typing import Literal
89
from typing import cast
10+
import uuid
911

1012
from ..autotuner.config_spec import DEFAULT_NUM_STAGES
1113
from ..autotuner.config_spec import DEFAULT_NUM_WARPS
@@ -118,7 +120,13 @@ def from_json(cls, json_str: str) -> Config:
118120

119121
def save(self, path: str | Path) -> None:
120122
"""Save the config to a JSON file."""
121-
Path(path).write_text(self.to_json())
123+
# Write to temp dir and rename to make the operation atomic
124+
# in case we are in a multithreaded environment
125+
Path(path).parent.mkdir(parents=True, exist_ok=True)
126+
127+
tmp = Path(path).parent / f"tmp.{uuid.uuid4()!s}"
128+
tmp.write_text(self.to_json())
129+
os.rename(str(tmp), str(path))
122130

123131
@classmethod
124132
def load(cls, path: str | Path) -> Config:

helion/runtime/kernel.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from torch._guards import Source
4646

4747
from ..autotuner import ConfigSpec
48+
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey
4849

4950
ConfigLike = Config | dict[str, object]
5051

@@ -53,12 +54,6 @@
5354
CompiledConfig = Callable[..., _R]
5455

5556

56-
@dataclasses.dataclass(frozen=True)
57-
class BoundKernelInMemoryCacheKey:
58-
specialization_key: tuple[Hashable, ...]
59-
extra_results: tuple[Hashable, ...]
60-
61-
6257
class Kernel(Generic[_R]):
6358
def __init__(
6459
self,
@@ -114,6 +109,8 @@ def __init__(
114109
def _get_bound_kernel_cache_key(
115110
self, args: tuple[object, ...], signature: tuple[Hashable, ...]
116111
) -> BoundKernelInMemoryCacheKey | None:
112+
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey
113+
117114
extra_fns = self._specialize_extra.get(signature)
118115
if extra_fns is not None:
119116
extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
@@ -126,6 +123,8 @@ def _create_bound_kernel_cache_key(
126123
args: tuple[object, ...],
127124
signature: tuple[Hashable, ...],
128125
) -> BoundKernelInMemoryCacheKey:
126+
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey
127+
129128
self._specialize_extra[signature] = extra_fns = bound_kernel._specialize_extra()
130129
extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
131130
return BoundKernelInMemoryCacheKey(signature, extra_results)
@@ -458,12 +457,18 @@ def autotune(
458457
self.settings.check_autotuning_disabled()
459458

460459
from ..autotuner import DifferentialEvolutionSearch
460+
from ..autotuner import LocalAutotuneCache
461461

462-
config = DifferentialEvolutionSearch(
462+
config = LocalAutotuneCache(
463463
self,
464464
args,
465-
**kwargs, # pyright: ignore[reportArgumentType]
465+
DifferentialEvolutionSearch(
466+
self,
467+
args,
468+
**kwargs, # pyright: ignore[reportArgumentType]
469+
),
466470
).autotune()
471+
467472
self.set_config(config)
468473
return config
469474

0 commit comments

Comments
 (0)