Skip to content

Commit be7d80a

Browse files
committed
Limit tensor block numel to triton's requirements
Fixes #456
1 parent 8819331 commit be7d80a

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

helion/autotuner/block_id_sequence.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import dataclasses
4+
import math
45
from typing import TYPE_CHECKING
56
from typing import Callable
67
from typing import MutableSequence
@@ -185,6 +186,10 @@ def _normalize(
185186
f"(one for each tiled dimension), got {len(values)}. "
186187
f"Did you forget to specify block sizes for all your hl.tile() dimensions?"
187188
) from None
189+
if name == "block_sizes" and math.prod(values) > 1048576:
190+
raise InvalidConfig(
191+
"Triton does not allow for tensor numel greater than 1048576"
192+
)
188193
for i, spec in enumerate(self._data):
189194
values[i] = spec._normalize(f"config[{name}][{i}]", values[i])
190195
return values

test/test_errors.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,25 @@ def fn(x: torch.Tensor) -> torch.Tensor:
6060
(torch.randn(4, 8, 16, device=DEVICE),),
6161
)
6262

63+
def test_invalid_tensor_numel(self):
64+
"""Test that InvalidConfig shows helpful message for invalid block sizes."""
65+
66+
@helion.kernel(config=helion.Config(block_sizes=[2048, 1024]))
67+
def fn(x: torch.Tensor) -> torch.Tensor:
68+
out = torch.zeros_like(x)
69+
for tile_m, tile_n in hl.tile(x.size()):
70+
out[tile_m, tile_n] = x[tile_m, tile_n]
71+
return out
72+
73+
with self.assertRaisesRegex(
74+
helion.exc.InvalidConfig,
75+
"Triton does not allow for tensor numel greater than 1048576",
76+
):
77+
code_and_output(
78+
fn,
79+
(torch.randn(2048, 2048, device=DEVICE),),
80+
)
81+
6382
def test_rank_mismatch_indexing(self):
6483
"""Test that RankMismatch shows tensor shapes in indexing errors."""
6584

0 commit comments

Comments
 (0)