Skip to content

Commit 8295250

Browse files
authored
Use autotuner's BoundKernel in caching (#388)
1 parent 3c3c64a commit 8295250

File tree

4 files changed

+8
-25
lines changed

4 files changed

+8
-25
lines changed

helion/autotuner/base_cache.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import os
99
from typing import TYPE_CHECKING
1010
from typing import Hashable
11-
from typing import Sequence
1211

1312
from torch._inductor.codecache import build_code_hash
1413
from torch._inductor.codecache import torch_key
@@ -18,7 +17,6 @@
1817

1918
if TYPE_CHECKING:
2019
from ..runtime.config import Config
21-
from ..runtime.kernel import BoundKernel
2220
from .base_search import BaseSearch
2321

2422
log: logging.Logger = logging.getLogger(__name__)
@@ -114,12 +112,10 @@ class AutotuneCacheBase(abc.ABC):
114112
provide implementations for get and put methods.
115113
"""
116114

117-
def __init__(
118-
self, kernel: BoundKernel, args: Sequence[object], autotuner: BaseSearch
119-
) -> None:
115+
def __init__(self, autotuner: BaseSearch) -> None:
120116
self.autotuner = autotuner
121-
self.kernel = kernel
122-
self.args = args
117+
self.kernel = self.autotuner.kernel
118+
self.args = self.autotuner.args
123119

124120
@abc.abstractmethod
125121
def get(self) -> Config | None:

helion/autotuner/local_cache.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pathlib import Path
88
import textwrap
99
from typing import TYPE_CHECKING
10-
from typing import Sequence
1110

1211
import torch
1312
from torch._inductor.runtime.cache_dir_utils import (
@@ -20,7 +19,6 @@
2019
from .base_cache import StrictAutotuneCacheKey
2120

2221
if TYPE_CHECKING:
23-
from ..runtime.kernel import BoundKernel
2422
from .base_search import BaseSearch
2523

2624
log: logging.Logger = logging.getLogger(__name__)
@@ -38,10 +36,8 @@ class LocalAutotuneCache(AutotuneCacheBase):
3836
PyTorch. Use StrictLocalAutotuneCache to consider these properties.
3937
"""
4038

41-
def __init__(
42-
self, kernel: BoundKernel, args: Sequence[object], autotuner: BaseSearch
43-
) -> None:
44-
super().__init__(kernel, args, autotuner)
39+
def __init__(self, autotuner: BaseSearch) -> None:
40+
super().__init__(autotuner)
4541
self.key = self._generate_key()
4642

4743
def _generate_key(self) -> LooseAutotuneCacheKey:

helion/runtime/kernel.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,8 +460,6 @@ def autotune(
460460
from ..autotuner import LocalAutotuneCache
461461

462462
config = LocalAutotuneCache(
463-
self,
464-
args,
465463
DifferentialEvolutionSearch(
466464
self,
467465
args,

test/test_cache.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,8 @@ def test_basic(self):
2828
b = torch.randn(16, device=DEVICE, dtype=torch.float16)
2929
args_b = (b, b)
3030

31-
# TODO(oulgen): Using a custom autotuner is very verbose, requires passing args 3 times etc
3231
bound_kernel = basic_kernels.add.bind(args_a)
33-
config = StrictLocalAutotuneCache(
34-
bound_kernel, args_a, BasicSearch(bound_kernel, args_a)
35-
).autotune()
32+
config = StrictLocalAutotuneCache(BasicSearch(bound_kernel, args_a)).autotune()
3633
bound_kernel.set_config(config)
3734
result = bound_kernel(*args_a)
3835
torch.testing.assert_close(result, a + a)
@@ -44,9 +41,7 @@ def test_basic(self):
4441
basic_kernels.add.reset()
4542

4643
bound_kernel = basic_kernels.add.bind(args_a)
47-
config = StrictLocalAutotuneCache(
48-
bound_kernel, args_a, BasicSearch(bound_kernel, args_a)
49-
).autotune()
44+
config = StrictLocalAutotuneCache(BasicSearch(bound_kernel, args_a)).autotune()
5045
bound_kernel.set_config(config)
5146
result = bound_kernel(*args_a)
5247
torch.testing.assert_close(result, a + a)
@@ -58,9 +53,7 @@ def test_basic(self):
5853
basic_kernels.add.reset()
5954

6055
bound_kernel = basic_kernels.add.bind(args_b)
61-
config = StrictLocalAutotuneCache(
62-
bound_kernel, args_b, BasicSearch(bound_kernel, args_b)
63-
).autotune()
56+
config = StrictLocalAutotuneCache(BasicSearch(bound_kernel, args_b)).autotune()
6457
bound_kernel.set_config(config)
6558
result = bound_kernel(*args_b)
6659
torch.testing.assert_close(result, b + b)

0 commit comments

Comments
 (0)