Commit fa08371

[BC breaking] Simplify block size configs (#127)
This removes the prior nested config structure, where you would have:

* block_sizes=[[8, 8], 8] for nested loops
* block_sizes=[64, 8] for a flattened loop

Instead, you now have:

* block_sizes=[8, 8, 8], flatten_loops=[False] for nested loops
* block_sizes=[8, 8, 8], flatten_loops=[True] for a flattened loop

This makes config structures more predictable and easier to work with.
1 parent c145af0 commit fa08371

23 files changed: 389 additions, 391 deletions
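For illustration, here is a minimal sketch of what the flattened structure looks like in practice, using the hypothetical (untuned) block sizes from the commit message above, for a kernel with one 2D tile and one 1D tile:

import helion

# Before this change the same kernel needed nested entries:
#   helion.Config(block_sizes=[[8, 8], 8])   # nested 2D loop + 1D loop
#   helion.Config(block_sizes=[64, 8])       # flattened 2D loop + 1D loop
# After this change block_sizes is always a flat list of ints:
nested_config = helion.Config(block_sizes=[8, 8, 8], flatten_loops=[False])
flattened_config = helion.Config(block_sizes=[8, 8, 8], flatten_loops=[True])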

README.md

Lines changed: 7 additions & 5 deletions

@@ -179,16 +179,18 @@ and configurations directly from your code.
 
 Helion configurations include the following options:
 
-* **block\_sizes** (`list[int | list[int]]`):
-  Controls tile sizes corresponding to each `hl.tile` invocation in the
-  kernel. For tiles with two or more dimensions, you can use either an
-  integer to flatten the iteration space into a single dimension or a list
-  of integers for multi-dimensional tiling.
+* **block\_sizes** (`list[int]`):
+  Controls tile sizes corresponding to each dimension passed `hl.tile` or call
+  to `hl.register_block_size` in the kernel.
 
 * **loop\_orders** (`list[list[int]]`):
   Contains one entry per `hl.tile` call with two or more dimensions,
   allowing you to permute the iteration order of the tiles.
 
+* **flatten_loops** (`list[bool]`):
+  Contains one entry per `hl.tile` call with two or more dimensions,
+  allowing you to flatten the iteration space into a single dimension.
+
 * **reduction\_loops** (`list[int | None]`):
   Contains one entry per reduction dimension (see
   `examples/softmax.py`). Using `None` triggers a persistent reduction,
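As a concrete illustration of how these options combine, here is a hypothetical (untuned) config for a kernel with one 2D `hl.tile` loop and one reduction dimension; the values are made up for this sketch:

import helion

config = helion.Config(
    block_sizes=[64, 32],     # one entry per tiled dimension
    loop_orders=[[1, 0]],     # permute the iteration order of the 2D tile
    flatten_loops=[False],    # keep the 2D tile as a nested loop
    reduction_loops=[None],   # None triggers a persistent reduction
)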

examples/attention.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 @helion.kernel(
     config=helion.Config(
         # This config was autotuned on a 3090, it won't be fast for other architectures
-        block_sizes=[[32], [16]],
+        block_sizes=[32, 16],
         num_warps=1,
         num_stages=2,
         indexing="block_ptr",

examples/embedding.py

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@
 
 @helion.kernel(
     config=helion.Config(
-        block_size=[512, 32], loop_order=[0, 1], num_warps=8, indexing="block_ptr"
+        block_sizes=[512, 32], loop_order=[0, 1], num_warps=8, indexing="block_ptr"
     )
 )
 def embedding(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:

examples/jagged_dense_add.py

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@
 
 @helion.kernel(
     config=helion.Config(
-        block_sizes=[[1], [512], [512]], num_warps=8, num_stages=4, indexing="block_ptr"
+        block_sizes=[1, 512, 512], num_warps=8, num_stages=4, indexing="block_ptr"
     )
 )
 def jagged_dense_add_2d(

examples/long_sum.py

Lines changed: 3 additions & 3 deletions

@@ -13,7 +13,7 @@ def baseline_sum(x: torch.Tensor) -> torch.Tensor:
 # Naive Reduction: Load the entire reduction dim at once, and reduce in reg.
 @helion.kernel(
     config=helion.Config(
-        block_sizes=[[1]],
+        block_sizes=[1],
         reduction_loops=[None],
         num_warps=32,
         num_stages=4,
@@ -32,7 +32,7 @@ def longsum(x: torch.Tensor) -> torch.Tensor:
 # Looped reduction
 @helion.kernel(
     config=helion.Config(
-        block_sizes=[[1]],
+        block_sizes=[1],
         reduction_loops=[
             32768
         ], # [None] for naive reduction, [tile_size] for looped reduction
@@ -53,7 +53,7 @@ def longsum_w_red_loop(x: torch.Tensor) -> torch.Tensor:
 # This generates the same code as above, but manually implements looped reduction.
 @helion.kernel(
     config=helion.Config(
-        block_sizes=[[32768], [1]], num_warps=16, num_stages=5, indexing="pointer"
+        block_sizes=[32768, 1], num_warps=16, num_stages=5, indexing="pointer"
     )
 )
 def longsum_manual(x: torch.Tensor) -> torch.Tensor:

examples/template_via_closure.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 @helion.kernel(
     # This was tuned on a 3090 and likely isn't optimal for other GPUs
     config=helion.Config(
-        block_sizes=[[64, 64], [16]],
+        block_sizes=[64, 64, 16],
         loop_orders=[[0, 1]],
         num_warps=2,
         num_stages=3,

helion/_compiler/compile_environment.py

Lines changed: 28 additions & 45 deletions

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import collections
+import contextlib
 import dataclasses
 import threading
 import types
@@ -86,9 +87,8 @@ def finalize_config_spec(self) -> None:
         from .tile_strategy import FlattenedTileStrategy
 
         for shape in self.kernel_tensor_sizes:
-            FlattenedTileStrategy.update_allow_flattened(
-                self.config_spec.block_size_specs, shape
-            )
+            FlattenedTileStrategy.update_allow_flattened(shape)
+        self.config_spec._remove_duplicates()
 
     def allocate_block_size(
         self,
@@ -343,68 +343,66 @@ def mark_alternate_size(self, size: torch.SymInt | int | None) -> None:
         if isinstance(self.size, AutoSize):
             # The block size was created by hl.register_block_size, and we didn't know the size yet.
             self.size = size
-            if isinstance(size, (int, torch.SymInt)) and isinstance(
-                source := self.block_size_source, LoopSpecBlockSizeSource
-            ):
-                # update the size hint now that we know the size
+            if size is not None:
                 env = CompileEnvironment.current()
-                env.config_spec.block_size_specs[source.loop_spec].update_hint(
-                    source.dim, env.size_hint(size)
-                )
+                with contextlib.suppress(KeyError):
+                    # update the size hint now that we know the size
+                    env.config_spec.block_sizes.block_id_lookup(
+                        self.block_size_idx
+                    ).update_hint(env.size_hint(size))
         elif size is None or self.size is None or self.size != size:
             self.size = None
 
     def symbol(self) -> sympy.Symbol:
         return self.var._sympy_()
 
     def from_config(self, config: Config) -> int | torch.SymInt | None:
-        return self.block_size_source.from_config(config)
+        return self.block_size_source.from_config(config, self.block_size_idx)
 
     def from_config_assert(self, config: Config) -> int | torch.SymInt:
         val = self.from_config(config)
         assert val is not None
         return val
 
     def is_flattened(self, config: Config) -> bool:
-        return self.block_size_source.is_flattened(config)
+        spec = CompileEnvironment.current().config_spec
+        return spec.flatten_loops.config_get(
+            config.flatten_loops, self.block_size_idx, False
+        )
 
     def is_grid(self) -> bool:
         return self.block_size_source.is_grid()
 
     def update_min_block(self, value: int, *, allow_flattened: bool = True) -> None:
-        return self.block_size_source.update_min_block(
-            value, allow_flattened=allow_flattened
-        )
+        spec = CompileEnvironment.current().config_spec
+        if not allow_flattened:
+            spec.flatten_loops.disable_block_id(self.block_size_idx)
+        with contextlib.suppress(KeyError):
+            spec.block_sizes.block_id_lookup(self.block_size_idx).update_min(value)
 
 
 class BlockSizeSource:
-    def from_config(self, config: Config) -> int | torch.SymInt | None:
+    def from_config(self, config: Config, block_id: int) -> int | torch.SymInt | None:
         raise NotImplementedError
 
-    def is_flattened(self, config: Config) -> bool:
-        return False
-
     def is_grid(self) -> bool:
         return False
 
     def l2_grouping(self, config: Config) -> int:
         return 1
 
-    def update_min_block(self, value: int, *, allow_flattened: bool = True) -> None:
-        return None
-
 
 @dataclasses.dataclass
 class FixedBlockSizeSource(BlockSizeSource):
     value: int | torch.SymInt
 
-    def from_config(self, config: Config) -> int | torch.SymInt:
+    def from_config(self, config: Config, block_id: int) -> int | torch.SymInt:
         return self.value
 
 
 @dataclasses.dataclass
 class GridBlockSizeSource(BlockSizeSource):
-    def from_config(self, config: Config) -> int:
+    def from_config(self, config: Config, block_id: int) -> int:
         raise NotImplementedError
 
     def is_grid(self) -> bool:
@@ -413,33 +411,18 @@ def is_grid(self) -> bool:
 
 @dataclasses.dataclass
 class LoopSpecBlockSizeSource(BlockSizeSource):
-    loop_spec: int
-    dim: int
-
-    def from_config(self, config: Config) -> int:
-        value = config.block_sizes[self.loop_spec]
-        if isinstance(value, int):
-            assert self.dim == 0
-            return value
-        return value[self.dim]
-
-    def is_flattened(self, config: Config) -> bool:
-        return isinstance(config.block_sizes[self.loop_spec], int)
-
-    def update_min_block(self, value: int, *, allow_flattened: bool = True) -> None:
-        """
-        Update the minimum block size for the given block index, only increases the minimum size.
-        """
-        spec = CompileEnvironment.current().config_spec.block_size_specs[self.loop_spec]
-        spec.update_min(self.dim, value)
-        spec.allow_flattened = spec.allow_flattened and allow_flattened
+    def from_config(self, config: Config, block_id: int) -> int:
+        index = CompileEnvironment.current().config_spec.block_sizes.block_id_to_index(
+            block_id
+        )
+        return config.block_sizes[index]
 
 
 @dataclasses.dataclass
 class ReductionLoopBlockSizeSource(BlockSizeSource):
     reduction_loop: int
 
-    def from_config(self, config: Config) -> int | None:
+    def from_config(self, config: Config, block_id: int) -> int | None:
         return config.reduction_loops[self.reduction_loop]
 
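Under the new scheme, a loop block size is identified by a block_id and read out of the flat `config.block_sizes` list through an index lookup, rather than through a (loop_spec, dim) pair. A simplified standalone sketch of that mapping (hypothetical names, not the actual helion classes):

# Hypothetical, simplified model of the block_id -> flat index lookup that
# LoopSpecBlockSizeSource.from_config now performs.
block_id_to_index = {0: 0, 1: 1, 2: 2}  # one flat entry per registered block size

def block_size_from_config(block_sizes: list[int], block_id: int) -> int:
    return block_sizes[block_id_to_index[block_id]]

assert block_size_from_config([64, 32, 8], block_id=2) == 8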

helion/_compiler/tile_dispatch.py

Lines changed: 6 additions & 1 deletion

@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 import collections
+import functools
+import operator
 from typing import TYPE_CHECKING
 
 from helion._compiler.compile_environment import CompileEnvironment
@@ -74,10 +76,13 @@ def _add_loop_strategy(
                 loop_order=loop_order,
             )
         elif block_size_infos[0].is_flattened(config):
+            block_size = functools.reduce(
+                operator.mul, [bs.from_config_assert(config) for bs in block_size_infos]
+            )
             strategy: TileStrategy = FlattenedTileStrategy(
                 fn,
                 block_indices,
-                block_size=block_size_infos[0].from_config_assert(config),
+                block_size=block_size,
                 loop_order=loop_order,
             )
         else:
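With this change, the flattened tile strategy derives its single block size as the product of the per-dimension entries in `config.block_sizes`. A standalone sketch of that arithmetic with hypothetical sizes:

import functools
import operator

# Hypothetical per-dimension block sizes for one hl.tile call; with
# flatten_loops=[True] the flattened strategy tiles the combined iteration
# space with their product (8 * 8 * 8 = 512 elements per tile).
per_dim_block_sizes = [8, 8, 8]
flattened_block_size = functools.reduce(operator.mul, per_dim_block_sizes)
assert flattened_block_size == 512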

helion/_compiler/tile_strategy.py

Lines changed: 7 additions & 27 deletions

@@ -18,7 +18,6 @@
 from .ast_extension import expr_from_string
 from .ast_extension import statement_from_string
 from .compile_environment import CompileEnvironment
-from .compile_environment import LoopSpecBlockSizeSource
 from .compile_environment import _to_sympy
 from .host_function import HostFunction
 from .program_id import GridProgramIDs
@@ -31,7 +30,6 @@
 if TYPE_CHECKING:
     from collections.abc import Sequence
 
-    from ..autotuner.config_spec import BlockSizeSpec
     from .device_function import DeviceFunction
     from .inductor_lowering import CodegenState
 
@@ -302,43 +300,25 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
         )
 
     @classmethod
-    def update_allow_flattened(
-        cls, specs: list[BlockSizeSpec], shape: Sequence[sympy.Expr]
-    ) -> None:
+    def update_allow_flattened(cls, shape: Sequence[sympy.Expr]) -> None:
         used_indices = {}
         for i, x in enumerate(shape):
             block_idx = cls.get_block_index(x)
             if block_idx is not None:
-                if block_idx in used_indices:
-                    # multiple usages of the same block size??? bail out
-                    for spec in specs:
-                        spec.allow_flattened = False
-                    return
                 used_indices[block_idx] = i
-        env = CompileEnvironment.current()
-        for spec_idx, group in itertools.groupby(
-            [
-                bs
-                for bs in env.block_sizes
-                if isinstance(bs.block_size_source, LoopSpecBlockSizeSource)
-            ],
-            key=lambda x: x.block_size_source.loop_spec,
-        ):
-            spec = specs[spec_idx]
-            if not spec.allow_flattened:
-                continue
-            block_indices = [bs.block_size_idx for bs in group]
-            if len(block_indices) == 1 or not (
+        flatten_loops = CompileEnvironment.current().config_spec.flatten_loops
+        for spec in [*flatten_loops]:
+            block_indices = spec.block_ids
+            if not (
                 all(x in used_indices for x in block_indices)
                 or all(x not in used_indices for x in block_indices)
            ):
-                # A shape must use all or none of the block indices in the group
-                spec.allow_flattened = False
+                flatten_loops.disable_block_id(block_indices[0])
                 continue
             for i, j in itertools.pairwise(block_indices):
                 if i in used_indices and used_indices[i] + 1 != used_indices[j]:
                     # The block indices must be contiguous
-                    spec.allow_flattened = False
+                    flatten_loops.disable_block_id(block_indices[0])
                     break
 
     def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]:

helion/_compiler/type_propagation.py

Lines changed: 7 additions & 15 deletions

@@ -964,25 +964,17 @@ def proxy(self) -> object:
 
     @staticmethod
     def allocate(
-        numels: list[int | torch.SymInt | AutoSize | None], origin: Origin
-    ) -> list[TileIndexType]:
+        numel: int | torch.SymInt | AutoSize | None, origin: Origin
+    ) -> TileIndexType:
         env = CompileEnvironment.current()
-        spec_id = len(env.config_spec.block_size_specs)
-        env.config_spec.block_size_specs.append(
+        block_id = env.allocate_block_size(numel, source=LoopSpecBlockSizeSource())
+        env.config_spec.block_sizes.append(
             BlockSizeSpec(
-                size_hints=[*map(_get_hint, numels)],
-                allow_flattened=len(numels) > 1,
+                block_id=block_id,
+                size_hint=_get_hint(numel),
             )
         )
-        return [
-            TileIndexType(
-                origin,
-                env.allocate_block_size(
-                    x, source=LoopSpecBlockSizeSource(spec_id, dim)
-                ),
-            )
-            for dim, x in enumerate(numels)
-        ]
+        return TileIndexType(origin, block_id)
 
     @staticmethod
     def allocate_fixed(
