
Commit f5594ca

Add autotuner_fn argument to @helion.kernel for custom autotuners (#394)
1 parent bb8e31d commit f5594ca

File tree

5 files changed (+52, -32 lines)

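In short, `@helion.kernel` now accepts an `autotuner_fn` setting: a callable that receives the bound kernel and its arguments and returns the autotuner to run. The default, `default_autotuner_fn`, preserves the previous behavior (a `LocalAutotuneCache` around `DifferentialEvolutionSearch`). A minimal usage sketch, modeled on the new test in this commit; the kernel body and the choice of `StrictLocalAutotuneCache` are illustrative, and running it assumes a supported GPU device:

import torch

import helion
import helion.language as hl
from helion.autotuner import DifferentialEvolutionSearch
from helion.autotuner import StrictLocalAutotuneCache


# Illustrative: swap the default cache wrapper for the strict variant.
# Any callable that returns a BaseAutotuner works here.
@helion.kernel(
    autotuner_fn=lambda kernel, args: StrictLocalAutotuneCache(
        DifferentialEvolutionSearch(kernel, args)
    )
)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    x, y = torch.broadcast_tensors(x, y)
    out = torch.empty_like(x)
    for tile in hl.tile(out.size()):
        out[tile] = x[tile] + y[tile]
    return out


out = add(torch.randn(16, device="cuda"), torch.randn(16, device="cuda"))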

helion/autotuner/base_cache.py

Lines changed: 2 additions & 1 deletion
@@ -13,6 +13,7 @@
 from torch._inductor.codecache import torch_key
 
 from .._utils import counters
+from .base_search import BaseAutotuner
 
 if TYPE_CHECKING:
     from ..runtime.config import Config
@@ -106,7 +107,7 @@ class StrictAutotuneCacheKey(LooseAutotuneCacheKey):
     triton_key: str = dataclasses.field(default_factory=triton_key_wrapper)
 
 
-class AutotuneCacheBase(abc.ABC):
+class AutotuneCacheBase(BaseAutotuner, abc.ABC):
     """
     Abstract base class that all autotune caches need to implement.
     Any user defined cache will need to extend this class, and

helion/autotuner/base_search.py

Lines changed: 12 additions & 1 deletion
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import abc
 import collections
 import contextlib
 import dataclasses
@@ -56,7 +57,17 @@
 )
 
 
-class BaseSearch:
+class BaseAutotuner(abc.ABC):
+    """
+    Abstract base class for all autotuners and classes that wrap autotuners, like caching.
+    """
+
+    @abc.abstractmethod
+    def autotune(self) -> Config:
+        raise NotImplementedError
+
+
+class BaseSearch(BaseAutotuner):
     """
     Base class for search algorithms. This class defines the interface and utilities for all
     search algorithms.
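`BaseAutotuner` is the new common interface: searches (`BaseSearch` and its subclasses) and wrappers such as the caches both expose `autotune()`. A hypothetical sketch of a conforming wrapper that relies only on the interface shown above; the class name and the timing logic are illustrative, not part of this commit:

import time

from helion.autotuner.base_search import BaseAutotuner
from helion.runtime.config import Config


class TimedAutotuner(BaseAutotuner):
    """Illustrative wrapper: delegates to an inner autotuner and reports how long tuning took."""

    def __init__(self, inner: BaseAutotuner) -> None:
        self.inner = inner

    def autotune(self) -> Config:
        start = time.perf_counter()
        config = self.inner.autotune()  # the only method BaseAutotuner requires
        print(f"autotuning finished in {time.perf_counter() - start:.1f}s")
        return config

A wrapper like this could be composed inside an autotuner_fn in the same way LocalAutotuneCache is composed in default_autotuner_fn below.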

helion/runtime/kernel.py

Lines changed: 1 addition & 11 deletions
@@ -467,17 +467,7 @@ def autotune(
             config = FiniteSearch(self, args, self.configs).autotune()
         else:
             self.settings.check_autotuning_disabled()
-
-            from ..autotuner import DifferentialEvolutionSearch
-            from ..autotuner import LocalAutotuneCache
-
-            config = LocalAutotuneCache(
-                DifferentialEvolutionSearch(
-                    self,
-                    args,
-                    **kwargs,  # pyright: ignore[reportArgumentType]
-                ),
-            ).autotune()
+            config = self.settings.autotuner_fn(self, args, **kwargs).autotune()
 
         self.set_config(config)
         return config

helion/runtime/settings.py

Lines changed: 20 additions & 0 deletions
@@ -8,6 +8,7 @@
 from typing import TYPE_CHECKING
 from typing import Literal
 from typing import Protocol
+from typing import Sequence
 from typing import cast
 
 import torch
@@ -19,9 +20,17 @@
 if TYPE_CHECKING:
     from contextlib import AbstractContextManager
 
+    from ..autotuner.base_search import BaseAutotuner
+    from .kernel import BoundKernel
+
 class _TLS(Protocol):
     default_settings: Settings | None
 
+class AutotunerFunction(Protocol):
+    def __call__(
+        self, bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
+    ) -> BaseAutotuner: ...
+
 
 _tls: _TLS = cast("_TLS", threading.local())
 
@@ -50,6 +59,15 @@ def __exit__(self, *args: object) -> None:
     return _RestoreContext()
 
 
+def default_autotuner_fn(
+    bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
+) -> BaseAutotuner:
+    from ..autotuner import DifferentialEvolutionSearch
+    from ..autotuner import LocalAutotuneCache
+
+    return LocalAutotuneCache(DifferentialEvolutionSearch(bound_kernel, args, **kwargs))  # pyright: ignore[reportArgumentType]
+
+
 @dataclasses.dataclass
 class _Settings:
     # see __slots__ below for the doc strings that show up in help(Settings)
@@ -76,6 +94,7 @@ class _Settings:
     ref_mode: RefMode = (
         RefMode.EAGER if os.environ.get("HELION_INTERPRET", "") == "1" else RefMode.OFF
     )
+    autotuner_fn: AutotunerFunction = default_autotuner_fn
 
 
 class Settings(_Settings):
@@ -97,6 +116,7 @@ class Settings(_Settings):
         "force_autotune": "If True, force autotuning even if a config is provided.",
         "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",
         "ref_mode": "Reference mode for kernel execution. Can be RefMode.OFF or RefMode.EAGER.",
+        "autotuner_fn": "Function to create an autotuner",
     }
     assert __slots__.keys() == {field.name for field in dataclasses.fields(_Settings)}
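`AutotunerFunction` is the protocol an `autotuner_fn` must satisfy, and `default_autotuner_fn` above is the reference implementation. A sketch of a custom function that skips the caching layer, purely as an illustration of the contract (the function name is hypothetical):

from __future__ import annotations

from typing import Sequence

from helion.autotuner import DifferentialEvolutionSearch
from helion.autotuner.base_search import BaseAutotuner
from helion.runtime.kernel import BoundKernel


def uncached_autotuner_fn(
    bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
) -> BaseAutotuner:
    # Same search as the default, just without the LocalAutotuneCache wrapper.
    return DifferentialEvolutionSearch(bound_kernel, args, **kwargs)  # pyright: ignore[reportArgumentType]

Such a function can then be supplied per kernel via @helion.kernel(autotuner_fn=uncached_autotuner_fn), as the test diff below demonstrates with a lambda.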

test/test_cache.py

Lines changed: 17 additions & 19 deletions
@@ -1,20 +1,17 @@
 from __future__ import annotations
 
-from pathlib import Path
 import unittest
 
 import torch
 
+import helion
 from helion._testing import DEVICE
 from helion._testing import RefEagerTestDisabled
 from helion._testing import TestCase
-from helion._testing import import_path
 from helion._utils import counters
 from helion.autotuner import StrictLocalAutotuneCache
 from helion.autotuner.base_search import BaseSearch
-
-datadir = Path(__file__).parent / "data"
-basic_kernels = import_path(datadir / "basic_kernels.py")
+import helion.language as hl
 
 
 class BasicSearch(BaseSearch):
@@ -24,39 +21,40 @@ def autotune(self):
 
 class TestCache(RefEagerTestDisabled, TestCase):
     def test_basic(self):
+        @helion.kernel(
+            autotuner_fn=lambda k, a: StrictLocalAutotuneCache(BasicSearch(k, a))
+        )
+        def add(x, y):
+            x, y = torch.broadcast_tensors(x, y)
+            out = torch.empty_like(x)
+            for tile in hl.tile(out.size()):
+                out[tile] = x[tile] + y[tile]
+            return out
+
         a = torch.randn(16, device=DEVICE, dtype=torch.bfloat16)
         args_a = (a, a)
         b = torch.randn(16, device=DEVICE, dtype=torch.float16)
         args_b = (b, b)
 
-        bound_kernel = basic_kernels.add.bind(args_a)
-        config = StrictLocalAutotuneCache(BasicSearch(bound_kernel, args_a)).autotune()
-        bound_kernel.set_config(config)
-        result = bound_kernel(*args_a)
+        result = add(*args_a)
         torch.testing.assert_close(result, a + a)
 
         self.assertEqual(counters["autotune"]["cache_miss"], 1)
        self.assertEqual(counters["autotune"]["cache_hit"], 0)
         self.assertEqual(counters["autotune"]["cache_put"], 1)
 
-        basic_kernels.add.reset()
+        add.reset()
 
-        bound_kernel = basic_kernels.add.bind(args_a)
-        config = StrictLocalAutotuneCache(BasicSearch(bound_kernel, args_a)).autotune()
-        bound_kernel.set_config(config)
-        result = bound_kernel(*args_a)
+        result = add(*args_a)
         torch.testing.assert_close(result, a + a)
 
         self.assertEqual(counters["autotune"]["cache_miss"], 1)
         self.assertEqual(counters["autotune"]["cache_hit"], 1)
         self.assertEqual(counters["autotune"]["cache_put"], 1)
 
-        basic_kernels.add.reset()
+        add.reset()
 
-        bound_kernel = basic_kernels.add.bind(args_b)
-        config = StrictLocalAutotuneCache(BasicSearch(bound_kernel, args_b)).autotune()
-        bound_kernel.set_config(config)
-        result = bound_kernel(*args_b)
+        result = add(*args_b)
         torch.testing.assert_close(result, b + b)
 
         self.assertEqual(counters["autotune"]["cache_miss"], 2)
