Skip to content

Fix tensor value assignment and partial indexing in Helion #439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: yf225/stack/54
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 113 additions & 3 deletions helion/_compiler/indexing_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ..runtime.config import Config
from .device_function import TensorDescriptorArg
from .inductor_lowering import CodegenState
from .tile_dispatch import TileStrategyDispatch

SymIntLike = torch.SymInt | int
ShapeLike = Sequence[SymIntLike]
Expand Down Expand Up @@ -61,6 +62,70 @@ def _normalize_negative_index(
return f"({state.codegen.device_function.user_sympy_expr(sympy_expr)})"


def _append_remaining_dimensions(
input_size: collections.deque,
output_size: list[int | torch.SymInt],
env: CompileEnvironment,
) -> None:
"""Append remaining dimensions from input to output for partial indexing.

Args:
input_size: Deque of remaining input dimensions
output_size: List to append output dimensions to
env: The compile environment
"""
while input_size:
size = input_size.popleft()
if size != 1:
rdim = env.allocate_reduction_dimension(size)
output_size.append(rdim.var)
else:
output_size.append(1)


def _handle_remaining_index_dimensions(
index_values: list[str],
mask_values: dict[str, None],
output_size: list[int | torch.SymInt],
output_idx: int,
fake_value: torch.Tensor,
state: CodegenState,
tile_strategy: TileStrategyDispatch,
env: CompileEnvironment,
dtype: str,
) -> int:
"""Handle remaining dimensions for partial indexing in SubscriptIndexing.create.

Args:
index_values: List to append index expressions to
mask_values: Dict to add mask expressions to
output_size: The output shape
output_idx: Current output index
fake_value: The tensor being indexed
state: The codegen state
tile_strategy: The tile strategy
env: The compile environment
dtype: The triton index type

Returns:
Updated output_idx
"""
while len(index_values) < fake_value.ndim:
expand = tile_strategy.expand_str(output_size, output_idx)
size = fake_value.size(len(index_values))
if size != 1:
rdim = env.allocate_reduction_dimension(size)
block_idx = rdim.block_id
index_var = state.codegen.index_var(block_idx)
index_values.append(f"({index_var}){expand}")
if mask := state.codegen.mask_var(block_idx):
mask_values.setdefault(f"({mask}){expand}")
else:
index_values.append(f"tl.zeros([1], {dtype}){expand}")
output_idx += 1
return output_idx


class IndexingStrategy:
def codegen_load(
self,
Expand Down Expand Up @@ -132,6 +197,32 @@ def codegen_store(
) -> ast.AST:
indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask)
name = state.device_function.tensor_arg(fake_tensor).name

# Check if value is a tensor load (Name node with id matching a tensor arg)
if isinstance(value, ast.Name) and hasattr(
state.device_function, "_tensor_args"
):
# Check if this name corresponds to a tensor argument
tensor = None
for t, tensor_arg in state.device_function._tensor_args.items():
if tensor_arg.name == value.id:
tensor = t
break

if tensor is not None:
# Get the shape of the slice we're storing to
output_shape = SubscriptIndexing.compute_shape(fake_tensor, subscript)
if len(output_shape) == 1 and tensor.ndim == 1:
# Load the entire 1D tensor
value_indexing = SubscriptIndexing.create(
state, tensor, [slice(None)], None
)
value = expr_from_string(
f"tl.load({value.id} + offset, mask)",
offset=value_indexing.index_expr,
mask=value_indexing.mask_expr,
)

return expr_from_string(
f"tl.store({name} + offset, value, mask)",
value=value,
Expand Down Expand Up @@ -503,7 +594,9 @@ def compute_shape(
) -> list[int | torch.SymInt]:
assert isinstance(tensor, torch.Tensor)
assert isinstance(index, (list, tuple)), index
input_size = collections.deque(tensor.size())
input_size: collections.deque[int | torch.SymInt] = collections.deque(
tensor.size()
)
output_size = []
env = CompileEnvironment.current()
for i, k in enumerate(index):
Expand Down Expand Up @@ -547,7 +640,8 @@ def compute_shape(
output_size.extend(k.size())
else:
raise exc.InvalidIndexingType(k)
assert len(input_size) == 0, "invalid subscript"
# For partial indexing, append remaining dimensions to output
_append_remaining_dimensions(input_size, output_size, env)
return output_size

@staticmethod
Expand Down Expand Up @@ -675,6 +769,20 @@ def create(
)
else:
raise exc.InvalidIndexingType(type(k))

# Handle remaining dimensions for partial indexing
output_idx = _handle_remaining_index_dimensions(
index_values,
mask_values,
output_size,
output_idx,
fake_value,
state,
tile_strategy,
env,
dtype,
)

assert len(output_size) == output_idx
assert len(index_values) == fake_value.ndim
index_expr = []
Expand Down Expand Up @@ -800,7 +908,9 @@ def is_supported(
if extra_mask is not None:
# TODO(jansel): support block_ptr with extra_mask
return False
input_sizes = collections.deque(fake_tensor.size())
input_sizes: collections.deque[int | torch.SymInt] = collections.deque(
fake_tensor.size()
)
for n, k in enumerate(index):
if k is None:
input_size = 1
Expand Down
22 changes: 17 additions & 5 deletions helion/_compiler/type_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import ast
import builtins
import collections
import contextlib
import dataclasses
import functools
Expand Down Expand Up @@ -485,12 +486,23 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
raise exc.OverpackedTile(k)
else:
raise exc.InvalidIndexingType(k)
if inputs_consumed != self.fake_value.ndim:
raise exc.RankMismatch(
self.fake_value.ndim,
inputs_consumed,
f"tensor shape: {tuple(self.fake_value.shape)}",
# Handle partial indexing - add remaining dimensions to output
if inputs_consumed < self.fake_value.ndim:
# Create a deque with remaining dimensions
remaining_sizes: collections.deque[int | torch.SymInt] = collections.deque(
self.fake_value.size(i)
for i in range(inputs_consumed, self.fake_value.ndim)
)
if self.origin.is_device():
# On device, just append the sizes directly
output_sizes.extend(remaining_sizes)
else:
# On host, use the helper to allocate reduction dimensions
from helion._compiler.indexing_strategy import (
_append_remaining_dimensions,
)

_append_remaining_dimensions(remaining_sizes, output_sizes, env)
return output_sizes

def propagate_setitem(
Expand Down
6 changes: 5 additions & 1 deletion helion/language/_tracing_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ def _host_tensor(debug_name: str) -> torch.Tensor:

@_decorators.codegen(_host_tensor)
def _(state: CodegenState) -> ast.AST:
    # Resolve the fake tensor recorded on the FX node during tracing.
    fake_tensor = state.fx_node.meta["val"]  # pyright: ignore[reportOptionalMemberAccess]
    # Look up the kernel argument registered for this tensor and emit its name.
    host_arg = state.device_function.tensor_arg(fake_tensor)
    return expr_from_string(host_arg.name)


@has_side_effect
Expand Down
3 changes: 0 additions & 3 deletions test/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,9 +808,6 @@ def kernel(
torch.testing.assert_close(src_result, expected_src)
torch.testing.assert_close(dst_result, expected_dst)

@skipIfNormalMode(
"RankMismatch: Expected ndim=2, but got ndim=1 - tensor value assignment shape mismatch"
)
def test_tensor_value(self):
"""Test both setter from tensor value and getter for [i]"""

Expand Down
Loading