diff --git a/helion/autotuner/block_id_sequence.py b/helion/autotuner/block_id_sequence.py index babc82e4..357104d1 100644 --- a/helion/autotuner/block_id_sequence.py +++ b/helion/autotuner/block_id_sequence.py @@ -1,6 +1,7 @@ from __future__ import annotations import dataclasses +import math from typing import TYPE_CHECKING from typing import Callable from typing import MutableSequence @@ -185,6 +186,10 @@ def _normalize( f"(one for each tiled dimension), got {len(values)}. " f"Did you forget to specify block sizes for all your hl.tile() dimensions?" ) from None + if name == "block_sizes" and math.prod(values) > 1048576: + raise InvalidConfig( + "Triton does not allow for tensor numel greater than 1048576" + ) for i, spec in enumerate(self._data): values[i] = spec._normalize(f"config[{name}][{i}]", values[i]) return values diff --git a/test/test_errors.py b/test/test_errors.py index f42bcb67..f842e93e 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -60,6 +60,25 @@ def fn(x: torch.Tensor) -> torch.Tensor: (torch.randn(4, 8, 16, device=DEVICE),), ) + def test_invalid_tensor_numel(self): + """Test that InvalidConfig is raised when the block-size product exceeds Triton's max tensor numel.""" + + @helion.kernel(config=helion.Config(block_sizes=[2048, 1024])) + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.zeros_like(x) + for tile_m, tile_n in hl.tile(x.size()): + out[tile_m, tile_n] = x[tile_m, tile_n] + return out + + with self.assertRaisesRegex( + helion.exc.InvalidConfig, + "Triton does not allow for tensor numel greater than 1048576", + ): + code_and_output( + fn, + (torch.randn(2048, 2048, device=DEVICE),), + ) + def test_rank_mismatch_indexing(self): """Test that RankMismatch shows tensor shapes in indexing errors."""