From f9b122af050e1191fce0526eade91f7fb9006d13 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 4 Aug 2025 15:37:36 -0700 Subject: [PATCH] Fix tensor value assignment and partial indexing in Helion stack-info: PR: https://github.com/pytorch/helion/pull/439, branch: yf225/stack/55 --- helion/_compiler/indexing_strategy.py | 116 +++++++++++++++++++++++++- helion/_compiler/type_propagation.py | 22 +++-- helion/language/_tracing_ops.py | 6 +- test/test_indexing.py | 3 - 4 files changed, 135 insertions(+), 12 deletions(-) diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index 06d06155..bb7491ce 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -26,6 +26,7 @@ from ..runtime.config import Config from .device_function import TensorDescriptorArg from .inductor_lowering import CodegenState + from .tile_dispatch import TileStrategyDispatch SymIntLike = torch.SymInt | int ShapeLike = Sequence[SymIntLike] @@ -61,6 +62,70 @@ def _normalize_negative_index( return f"({state.codegen.device_function.user_sympy_expr(sympy_expr)})" +def _append_remaining_dimensions( + input_size: collections.deque, + output_size: list[int | torch.SymInt], + env: CompileEnvironment, +) -> None: + """Append remaining dimensions from input to output for partial indexing. + + Args: + input_size: Deque of remaining input dimensions + output_size: List to append output dimensions to + env: The compile environment + """ + while input_size: + size = input_size.popleft() + if size != 1: + rdim = env.allocate_reduction_dimension(size) + output_size.append(rdim.var) + else: + output_size.append(1) + + +def _handle_remaining_index_dimensions( + index_values: list[str], + mask_values: dict[str, None], + output_size: list[int | torch.SymInt], + output_idx: int, + fake_value: torch.Tensor, + state: CodegenState, + tile_strategy: TileStrategyDispatch, + env: CompileEnvironment, + dtype: str, +) -> int: + """Handle remaining dimensions for partial indexing in SubscriptIndexing.create. + + Args: + index_values: List to append index expressions to + mask_values: Dict to add mask expressions to + output_size: The output shape + output_idx: Current output index + fake_value: The tensor being indexed + state: The codegen state + tile_strategy: The tile strategy + env: The compile environment + dtype: The triton index type + + Returns: + Updated output_idx + """ + while len(index_values) < fake_value.ndim: + expand = tile_strategy.expand_str(output_size, output_idx) + size = fake_value.size(len(index_values)) + if size != 1: + rdim = env.allocate_reduction_dimension(size) + block_idx = rdim.block_id + index_var = state.codegen.index_var(block_idx) + index_values.append(f"({index_var}){expand}") + if mask := state.codegen.mask_var(block_idx): + mask_values.setdefault(f"({mask}){expand}") + else: + index_values.append(f"tl.zeros([1], {dtype}){expand}") + output_idx += 1 + return output_idx + + class IndexingStrategy: def codegen_load( self, @@ -132,6 +197,32 @@ def codegen_store( ) -> ast.AST: indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask) name = state.device_function.tensor_arg(fake_tensor).name + + # Check if value is a tensor load (Name node with id matching a tensor arg) + if isinstance(value, ast.Name) and hasattr( + state.device_function, "_tensor_args" + ): + # Check if this name corresponds to a tensor argument + tensor = None + for t, tensor_arg in state.device_function._tensor_args.items(): + if tensor_arg.name == value.id: + tensor = t + break + + if tensor is not None: + # Get the shape of the slice we're storing to + output_shape = SubscriptIndexing.compute_shape(fake_tensor, subscript) + if len(output_shape) == 1 and tensor.ndim == 1: + # Load the entire 1D tensor + value_indexing = SubscriptIndexing.create( + state, tensor, [slice(None)], None + ) + value = expr_from_string( + f"tl.load({value.id} + offset, mask)", + offset=value_indexing.index_expr, + mask=value_indexing.mask_expr, + ) + return expr_from_string( f"tl.store({name} + offset, value, mask)", value=value, @@ -503,7 +594,9 @@ def compute_shape( ) -> list[int | torch.SymInt]: assert isinstance(tensor, torch.Tensor) assert isinstance(index, (list, tuple)), index - input_size = collections.deque(tensor.size()) + input_size: collections.deque[int | torch.SymInt] = collections.deque( + tensor.size() + ) output_size = [] env = CompileEnvironment.current() for i, k in enumerate(index): @@ -547,7 +640,8 @@ def compute_shape( output_size.extend(k.size()) else: raise exc.InvalidIndexingType(k) - assert len(input_size) == 0, "invalid subscript" + # For partial indexing, append remaining dimensions to output + _append_remaining_dimensions(input_size, output_size, env) return output_size @staticmethod @@ -675,6 +769,20 @@ def create( ) else: raise exc.InvalidIndexingType(type(k)) + + # Handle remaining dimensions for partial indexing + output_idx = _handle_remaining_index_dimensions( + index_values, + mask_values, + output_size, + output_idx, + fake_value, + state, + tile_strategy, + env, + dtype, + ) + assert len(output_size) == output_idx assert len(index_values) == fake_value.ndim index_expr = [] @@ -800,7 +908,9 @@ def is_supported( if extra_mask is not None: # TODO(jansel): support block_ptr with extra_mask return False - input_sizes = collections.deque(fake_tensor.size()) + input_sizes: collections.deque[int | torch.SymInt] = collections.deque( + fake_tensor.size() + ) for n, k in enumerate(index): if k is None: input_size = 1 diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py index 97a042b2..9355fefb 100644 --- a/helion/_compiler/type_propagation.py +++ b/helion/_compiler/type_propagation.py @@ -2,6 +2,7 @@ import ast import builtins +import collections import contextlib import dataclasses import functools @@ -485,12 +486,23 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]: raise exc.OverpackedTile(k) else: raise exc.InvalidIndexingType(k) - if inputs_consumed != self.fake_value.ndim: - raise exc.RankMismatch( - self.fake_value.ndim, - inputs_consumed, - f"tensor shape: {tuple(self.fake_value.shape)}", + # Handle partial indexing - add remaining dimensions to output + if inputs_consumed < self.fake_value.ndim: + # Create a deque with remaining dimensions + remaining_sizes: collections.deque[int | torch.SymInt] = collections.deque( + self.fake_value.size(i) + for i in range(inputs_consumed, self.fake_value.ndim) ) + if self.origin.is_device(): + # On device, just append the sizes directly + output_sizes.extend(remaining_sizes) + else: + # On host, use the helper to allocate reduction dimensions + from helion._compiler.indexing_strategy import ( + _append_remaining_dimensions, + ) + + _append_remaining_dimensions(remaining_sizes, output_sizes, env) return output_sizes def propagate_setitem( diff --git a/helion/language/_tracing_ops.py b/helion/language/_tracing_ops.py index 2f5e42b9..8c90d8b7 100644 --- a/helion/language/_tracing_ops.py +++ b/helion/language/_tracing_ops.py @@ -68,7 +68,11 @@ def _host_tensor(debug_name: str) -> torch.Tensor: @_decorators.codegen(_host_tensor) def _(state: CodegenState) -> ast.AST: - return expr_from_string("_host_tensor") # should be unused + # Get the tensor from the FX node metadata + tensor = state.fx_node.meta["val"] # pyright: ignore[reportOptionalMemberAccess] + # Get the tensor argument name from the device function + tensor_arg = state.device_function.tensor_arg(tensor) + return expr_from_string(tensor_arg.name) @has_side_effect diff --git a/test/test_indexing.py b/test/test_indexing.py index 3ca95fd6..f63cde16 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -808,9 +808,6 @@ def kernel( torch.testing.assert_close(src_result, expected_src) torch.testing.assert_close(dst_result, expected_dst) - @skipIfNormalMode( - "RankMismatch: Expected ndim=2, but got ndim=1 - tensor value assignment shape mismatch" - ) def test_tensor_value(self): """Test both setter from tensor value and getter for [i]"""