Skip to content

Commit 4efd198

Browse files
authored
Move BlockIdSequence to its own file (#129)
1 parent c38fa55 commit 4efd198

File tree

2 files changed

+196
-180
lines changed

2 files changed

+196
-180
lines changed

helion/autotuner/block_id_sequence.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
from __future__ import annotations
2+
3+
import dataclasses
4+
from typing import TYPE_CHECKING
5+
from typing import Callable
6+
from typing import MutableSequence
7+
from typing import TypeVar
8+
9+
from torch.fx.node import map_aggregate
10+
11+
from ..exc import InvalidConfig
12+
from .config_fragment import ConfigSpecFragment
13+
from .config_fragment import assert_integer_power_of_two
14+
15+
if TYPE_CHECKING:
16+
from . import ConfigSpec
17+
18+
_T = TypeVar("_T")
19+
_D = TypeVar("_D")
20+
21+
22+
@dataclasses.dataclass
23+
class _BlockIdItem:
24+
# the block_indices used in the IR
25+
block_ids: list[int]
26+
27+
@property
28+
def block_id(self) -> int:
29+
"""Return the first block_id for this item."""
30+
return self.block_ids[0]
31+
32+
def _fill_missing(self) -> object:
33+
"""Provide a value when not provided by the user."""
34+
raise NotImplementedError
35+
36+
def _normalize(self, name: str, value: object) -> object:
37+
"""Validate and normalize the value for this item."""
38+
raise NotImplementedError
39+
40+
def _fragment(self, base: ConfigSpec) -> ConfigSpecFragment:
41+
"""Return the fragment used for autotunging for this item."""
42+
raise NotImplementedError
43+
44+
def _flat_config(
45+
self, base: ConfigSpec, fn: Callable[[ConfigSpecFragment], object]
46+
) -> object:
47+
return fn(self._fragment(base))
48+
49+
50+
_BlockIdItemT = TypeVar("_BlockIdItemT", bound=_BlockIdItem)
51+
52+
53+
class BlockIdSequence(MutableSequence[_BlockIdItemT]):
54+
"""
55+
A sequence of _BlockIdItem subclasses that allows for efficient
56+
mapping from block_id to index in the sequence. A generic data
57+
structure used to store different types of configuration specs.
58+
"""
59+
60+
def __init__(self) -> None:
61+
self._data: list[_BlockIdItemT] = []
62+
self._block_id_to_index: dict[int, int] = {}
63+
64+
def __len__(self) -> int:
65+
return len(self._data)
66+
67+
def _reindex(self) -> None:
68+
"""Rebuild the mapping from block_id to index."""
69+
new_index = {}
70+
for i, item in enumerate(self._data):
71+
for block_id in item.block_ids:
72+
new_index[block_id] = i
73+
self._block_id_to_index = new_index
74+
75+
def __getitem__(self, index: int) -> _BlockIdItemT:
76+
return self._data[index]
77+
78+
def __setitem__(self, index: int, value: _BlockIdItemT) -> None:
79+
self._data[index] = value
80+
self._reindex() # could be faster, but uncommon case
81+
82+
def __delitem__(self, index: int) -> None:
83+
del self._data[index]
84+
self._reindex() # could be faster, but uncommon case
85+
86+
def clear(self) -> None:
87+
self._data.clear()
88+
self._block_id_to_index.clear()
89+
90+
def append(self, value: _BlockIdItemT) -> None:
91+
"""Append a new item to the end of the sequence."""
92+
index = len(self._data)
93+
self._data.append(value)
94+
for block_id in value.block_ids:
95+
self._block_id_to_index[block_id] = index
96+
97+
def insert(self, index: int, value: _BlockIdItemT) -> None:
98+
"""Insert a new item at the given index."""
99+
if index == len(self._data):
100+
self.append(value)
101+
return
102+
self._data.insert(index, value)
103+
self._reindex() # could be faster, but uncommon case
104+
105+
def block_id_to_index(self, block_id: int) -> int:
106+
"""Return the index of the block_id in the config."""
107+
return self._block_id_to_index[block_id]
108+
109+
def block_id_lookup(self, block_id: int) -> _BlockIdItemT:
110+
"""Return the index of the block_id in the config."""
111+
return self._data[self._block_id_to_index[block_id]]
112+
113+
def disable_block_id(self, block_id: int) -> None:
114+
"""Remove configuration choice for the given block_id."""
115+
self._data = [x for x in self._data if block_id not in x.block_ids]
116+
self._reindex()
117+
118+
def config_get(
119+
self, config: list[_T], block_id: int, default: _D = None
120+
) -> _T | _D:
121+
"""
122+
Get the config value for the given block_id, or return default if not found.
123+
"""
124+
index = self._block_id_to_index.get(block_id, None)
125+
if index is None:
126+
return default
127+
return config[index]
128+
129+
def _flat_config(
130+
self, base: ConfigSpec, fn: Callable[[ConfigSpecFragment], object]
131+
) -> list[object]:
132+
"""Map a flattened version of the config using the given function."""
133+
return [spec._flat_config(base, fn) for spec in self._data]
134+
135+
def _normalize(
136+
self, name: str, values: object, *, flatten: bool = False
137+
) -> list[object]:
138+
"""Validate and normalize the values for this config item."""
139+
if flatten:
140+
if values is None:
141+
values = ()
142+
new_values = []
143+
# pyre-ignore[6]
144+
map_aggregate(values, new_values.append)
145+
values = new_values
146+
elif not isinstance(values, (list, tuple, type(None))):
147+
raise InvalidConfig(
148+
f"Unexpected type for config[{name!r}], expected list or None, got {type(values).__name__}"
149+
)
150+
values = [*(values or ())]
151+
size = len(self)
152+
if len(values) > size:
153+
raise InvalidConfig(
154+
f"Too many values for config[{name!r}], expected {size}, got {len(values)}"
155+
)
156+
if len(values) < size:
157+
try:
158+
for spec in self._data[len(values) :]:
159+
values.append(spec._fill_missing())
160+
except NotImplementedError:
161+
raise InvalidConfig(
162+
f"Not enough values for config[{name!r}], expected {size}, got {len(values)}"
163+
) from None
164+
for i, spec in enumerate(self._data):
165+
values[i] = spec._normalize(f"config[{name}][{i}]", values[i])
166+
return values
167+
168+
def _remove_duplicates(self) -> None:
169+
new_specs = []
170+
for spec in self:
171+
other = self.block_id_lookup(spec.block_id)
172+
if other is spec:
173+
new_specs.append(spec)
174+
elif len(spec.block_ids) != len(other.block_ids):
175+
# this will cause invalid config errors with loop orders
176+
# remove them both
177+
self.disable_block_id(spec.block_id)
178+
self._remove_duplicates() # start over
179+
return
180+
if len(new_specs) != len(self):
181+
self._data = new_specs
182+
self._reindex()
183+
184+
185+
class _PowerOfTwoBlockIdItem(_BlockIdItem):
186+
def _normalize(self, name: str, value: object) -> int | None:
187+
try:
188+
return assert_integer_power_of_two(value)
189+
except InvalidConfig:
190+
raise InvalidConfig(
191+
f"{name} must be a power of two, got {value!r}"
192+
) from None

0 commit comments

Comments
 (0)