Limit tensor block numel to triton's requirements #485

Status: Open · wants to merge 1 commit into main
5 changes: 5 additions & 0 deletions helion/autotuner/block_id_sequence.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import dataclasses
import math
from typing import TYPE_CHECKING
from typing import Callable
from typing import MutableSequence
@@ -185,6 +186,10 @@ def _normalize(
                f"(one for each tiled dimension), got {len(values)}. "
                f"Did you forget to specify block sizes for all your hl.tile() dimensions?"
            ) from None
        if name == "block_sizes" and math.prod(values) > 1048576:
            raise InvalidConfig(
                "Triton does not allow for tensor numel greater than 1048576"
            )
Review comment on lines +189 to +192 (Contributor):
Not all block sizes should be included in this count; only the ones in the top-level loop should be. We might need to tag some block sizes as coming from the grid and only count those.

        for i, spec in enumerate(self._data):
            values[i] = spec._normalize(f"config[{name}][{i}]", values[i])
        return values
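The reviewer's suggestion above, counting only the block sizes that originate from the top-level grid, could be sketched roughly as follows. This is a standalone illustration, not Helion's actual API: `check_block_numel`, the `specs` list, and the `from_grid` flag are all hypothetical names.

```python
import math

# Triton's per-tensor element limit: 2**20
TRITON_MAX_TENSOR_NUMEL = 1048576


def check_block_numel(values, specs):
    """Validate only the block sizes tagged as coming from the grid.

    `specs` is a hypothetical list parallel to `values`; each entry carries a
    `from_grid` flag. Block sizes from inner (non-grid) loops are excluded
    from the product, per the reviewer's suggestion.
    """
    grid_numel = math.prod(
        v for v, spec in zip(values, specs) if spec.from_grid
    )
    if grid_numel > TRITON_MAX_TENSOR_NUMEL:
        raise ValueError(
            "Triton does not allow for tensor numel greater than "
            f"{TRITON_MAX_TENSOR_NUMEL}"
        )
```

With this shape, a config like `[2048, 1024]` would only be rejected when both dimensions come from the grid; if one of them belongs to an inner loop, it no longer counts toward the limit.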
19 changes: 19 additions & 0 deletions test/test_errors.py
@@ -60,6 +60,25 @@ def fn(x: torch.Tensor) -> torch.Tensor:
            (torch.randn(4, 8, 16, device=DEVICE),),
        )

    def test_invalid_tensor_numel(self):
        """Test that InvalidConfig is raised when the block-size product exceeds Triton's tensor numel limit."""

        @helion.kernel(config=helion.Config(block_sizes=[2048, 1024]))
        def fn(x: torch.Tensor) -> torch.Tensor:
            out = torch.zeros_like(x)
            for tile_m, tile_n in hl.tile(x.size()):
                out[tile_m, tile_n] = x[tile_m, tile_n]
            return out

        with self.assertRaisesRegex(
            helion.exc.InvalidConfig,
            "Triton does not allow for tensor numel greater than 1048576",
        ):
            code_and_output(
                fn,
                (torch.randn(2048, 2048, device=DEVICE),),
            )

    def test_rank_mismatch_indexing(self):
        """Test that RankMismatch shows tensor shapes in indexing errors."""

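For reference, the test's config trips the check because the product of its block sizes exceeds 2**20, while a config whose product is exactly at the limit passes the strict `>` comparison. This can be verified with plain Python, no Helion needed:

```python
import math

# Triton's limit, as used in the check above: 2**20
TRITON_MAX_NUMEL = 1048576

# The config in the test: block_sizes=[2048, 1024]
print(math.prod([2048, 1024]))  # 2097152, over the limit -> InvalidConfig

# A product of exactly 2**20 is still allowed by the strict `>` comparison
print(math.prod([1024, 1024]) > TRITON_MAX_NUMEL)  # False
```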