Skip to content

Commit 78e4663

Browse files
authored
Refactor BoundKernel in memory caching (#351)
1 parent 2249441 commit 78e4663

File tree

1 file changed

+38
-17
lines changed

1 file changed

+38
-17
lines changed

helion/runtime/kernel.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@
5353
CompiledConfig = Callable[..., _R]
5454

5555

56+
@dataclasses.dataclass(frozen=True)
57+
class BoundKernelInMemoryCacheKey:
58+
specialization_key: tuple[Hashable, ...]
59+
extra_results: tuple[Hashable, ...]
60+
61+
5662
class Kernel(Generic[_R]):
5763
def __init__(
5864
self,
@@ -80,7 +86,7 @@ def __init__(
8086
Config(**c) if isinstance(c, dict) else c # pyright: ignore[reportArgumentType]
8187
for c in configs or []
8288
]
83-
self._bound_kernels: dict[Hashable, BoundKernel] = {}
89+
self._bound_kernels: dict[BoundKernelInMemoryCacheKey, BoundKernel] = {}
8490
self._specialize_extra: dict[
8591
Hashable, list[Callable[[Sequence[object]], Hashable]]
8692
] = {}
@@ -105,6 +111,25 @@ def __init__(
105111
else:
106112
self._annotations.append(ann)
107113

114+
def _get_bound_kernel_cache_key(
115+
self, args: tuple[object, ...], signature: tuple[Hashable, ...]
116+
) -> BoundKernelInMemoryCacheKey | None:
117+
extra_fns = self._specialize_extra.get(signature)
118+
if extra_fns is not None:
119+
extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
120+
return BoundKernelInMemoryCacheKey(signature, extra_results)
121+
return None
122+
123+
def _create_bound_kernel_cache_key(
124+
self,
125+
bound_kernel: BoundKernel,
126+
args: tuple[object, ...],
127+
signature: tuple[Hashable, ...],
128+
) -> BoundKernelInMemoryCacheKey:
129+
self._specialize_extra[signature] = extra_fns = bound_kernel._specialize_extra()
130+
extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
131+
return BoundKernelInMemoryCacheKey(signature, extra_results)
132+
108133
def bind(self, args: tuple[object, ...]) -> BoundKernel[_R]:
109134
"""
110135
Bind the given arguments to the Kernel and return a BoundKernel object.
@@ -119,28 +144,22 @@ def bind(self, args: tuple[object, ...]) -> BoundKernel[_R]:
119144
assert isinstance(args, list), "args must be a tuple or list"
120145
args = tuple(args)
121146
signature = self.specialization_key(args)
122-
extra_fns = self._specialize_extra.get(signature)
123-
if extra_fns is not None:
124-
extra_results: list[Hashable] = [s(args) for s in extra_fns]
125-
signature_extra = (*signature, *extra_results)
126-
bound_kernel = self._bound_kernels.get(signature_extra)
127-
else:
128-
signature_extra = None
129-
bound_kernel = None
147+
cache_key = self._get_bound_kernel_cache_key(args, signature)
148+
bound_kernel = (
149+
None if cache_key is None else self._bound_kernels.get(cache_key, None)
150+
)
130151
if bound_kernel is None:
131152
normalized_args: tuple[object, ...] = self.normalize_args(*args)
132153
if len(normalized_args) != len(args):
133154
# we had default args that needed to be applied
134155
bound_kernel = self.bind(normalized_args)
135156
else:
136157
bound_kernel = BoundKernel(self, args)
137-
if signature_extra is None:
138-
self._specialize_extra[signature] = extra_fns = (
139-
bound_kernel._specialize_extra()
158+
if cache_key is None:
159+
cache_key = self._create_bound_kernel_cache_key(
160+
bound_kernel, args, signature
140161
)
141-
extra_results = [s(args) for s in extra_fns]
142-
signature_extra = (*signature, *extra_results)
143-
self._bound_kernels[signature_extra] = bound_kernel
162+
self._bound_kernels[cache_key] = bound_kernel
144163
return bound_kernel
145164

146165
def specialization_key(self, args: Sequence[object]) -> tuple[Hashable, ...]:
@@ -608,16 +627,18 @@ def kernel(
608627

609628

610629
def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
630+
# NOTE: If a machine has two different gpu types on the same machine,
631+
# obj.device.type will incorrectly hit
611632
if fn.settings.static_shapes:
612633
return (
613634
obj.dtype,
614-
obj.device,
635+
obj.device.type,
615636
(*obj.size(),),
616637
(*obj.stride(),),
617638
)
618639
return (
619640
obj.dtype,
620-
obj.device,
641+
obj.device.type,
621642
# 0, 1, or >=2 specialization
622643
tuple([min(s, 2) for s in obj.size()]),
623644
)

0 commit comments

Comments
 (0)