
Commit f9b122a

Fix tensor value assignment and partial indexing in Helion
stack-info: PR: #439, branch: yf225/stack/55
1 parent 3976471 commit f9b122a

4 files changed: +135 -12 lines

helion/_compiler/indexing_strategy.py

Lines changed: 113 additions & 3 deletions
```diff
@@ -26,6 +26,7 @@
 from ..runtime.config import Config
 from .device_function import TensorDescriptorArg
 from .inductor_lowering import CodegenState
+from .tile_dispatch import TileStrategyDispatch
 
 SymIntLike = torch.SymInt | int
 ShapeLike = Sequence[SymIntLike]
@@ -61,6 +62,70 @@ def _normalize_negative_index(
     return f"({state.codegen.device_function.user_sympy_expr(sympy_expr)})"
 
 
+def _append_remaining_dimensions(
+    input_size: collections.deque,
+    output_size: list[int | torch.SymInt],
+    env: CompileEnvironment,
+) -> None:
+    """Append remaining dimensions from input to output for partial indexing.
+
+    Args:
+        input_size: Deque of remaining input dimensions
+        output_size: List to append output dimensions to
+        env: The compile environment
+    """
+    while input_size:
+        size = input_size.popleft()
+        if size != 1:
+            rdim = env.allocate_reduction_dimension(size)
+            output_size.append(rdim.var)
+        else:
+            output_size.append(1)
+
+
+def _handle_remaining_index_dimensions(
+    index_values: list[str],
+    mask_values: dict[str, None],
+    output_size: list[int | torch.SymInt],
+    output_idx: int,
+    fake_value: torch.Tensor,
+    state: CodegenState,
+    tile_strategy: TileStrategyDispatch,
+    env: CompileEnvironment,
+    dtype: str,
+) -> int:
+    """Handle remaining dimensions for partial indexing in SubscriptIndexing.create.
+
+    Args:
+        index_values: List to append index expressions to
+        mask_values: Dict to add mask expressions to
+        output_size: The output shape
+        output_idx: Current output index
+        fake_value: The tensor being indexed
+        state: The codegen state
+        tile_strategy: The tile strategy
+        env: The compile environment
+        dtype: The triton index type
+
+    Returns:
+        Updated output_idx
+    """
+    while len(index_values) < fake_value.ndim:
+        expand = tile_strategy.expand_str(output_size, output_idx)
+        size = fake_value.size(len(index_values))
+        if size != 1:
+            rdim = env.allocate_reduction_dimension(size)
+            block_idx = rdim.block_id
+            index_var = state.codegen.index_var(block_idx)
+            index_values.append(f"({index_var}){expand}")
+            if mask := state.codegen.mask_var(block_idx):
+                mask_values.setdefault(f"({mask}){expand}")
+        else:
+            index_values.append(f"tl.zeros([1], {dtype}){expand}")
+        output_idx += 1
+    return output_idx
+
+
 class IndexingStrategy:
     def codegen_load(
         self,
```
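For intuition, here is a self-contained toy rendition of what `_handle_remaining_index_dimensions` produces: one broadcast-expanded index string per unconsumed trailing dimension, plus a mask for each non-unit dimension. The `rindex_*`/`rmask_*` names and the inline `expand` computation are invented stand-ins for Helion's real `index_var`/`mask_var` and `tile_strategy.expand_str`:

```python
def handle_remaining(index_values: list[str], sizes: tuple[int, ...], out_rank: int) -> list[str]:
    # Walk the trailing dims not covered by the explicit subscript.
    masks = []
    while len(index_values) < len(sizes):
        dim = len(index_values)
        # Broadcast suffix, e.g. "[None, :]", placing this dim at position `dim`.
        expand = "[" + ", ".join(":" if i == dim else "None" for i in range(out_rank)) + "]"
        if sizes[dim] != 1:
            index_values.append(f"(rindex_{dim}){expand}")
            masks.append(f"(rmask_{dim}){expand}")  # non-unit dims need a mask
        else:
            index_values.append(f"tl.zeros([1], tl.int32){expand}")  # size-1 dim: constant zero index
    return masks

idx = ["(tile_index_0)[:, None]"]  # explicit index on dim 0 of an (8, 16) tensor
masks = handle_remaining(idx, sizes=(8, 16), out_rank=2)
print(idx)    # ['(tile_index_0)[:, None]', '(rindex_1)[None, :]']
print(masks)  # ['(rmask_1)[None, :]']
```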
```diff
@@ -132,6 +197,32 @@ def codegen_store(
     ) -> ast.AST:
         indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask)
         name = state.device_function.tensor_arg(fake_tensor).name
+
+        # Check if value is a tensor load (Name node with id matching a tensor arg)
+        if isinstance(value, ast.Name) and hasattr(
+            state.device_function, "_tensor_args"
+        ):
+            # Check if this name corresponds to a tensor argument
+            tensor = None
+            for t, tensor_arg in state.device_function._tensor_args.items():
+                if tensor_arg.name == value.id:
+                    tensor = t
+                    break
+
+            if tensor is not None:
+                # Get the shape of the slice we're storing to
+                output_shape = SubscriptIndexing.compute_shape(fake_tensor, subscript)
+                if len(output_shape) == 1 and tensor.ndim == 1:
+                    # Load the entire 1D tensor
+                    value_indexing = SubscriptIndexing.create(
+                        state, tensor, [slice(None)], None
+                    )
+                    value = expr_from_string(
+                        f"tl.load({value.id} + offset, mask)",
+                        offset=value_indexing.index_expr,
+                        mask=value_indexing.mask_expr,
+                    )
+
         return expr_from_string(
             f"tl.store({name} + offset, value, mask)",
             value=value,
```
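The rewrite above exists because a bare tensor-argument name is not a value Triton can store; it has to be materialized with a masked `tl.load` first. A hand-written Triton analogue of the load-then-store pattern the generated code follows (the kernel, its names, and the row layout are illustrative, not Helion's actual output):

```python
import triton
import triton.language as tl

@triton.jit
def copy_row_kernel(dst, src, row, N, BLOCK: tl.constexpr):
    # Copy the 1D tensor `src` into row `row` of the 2D tensor `dst`.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    value = tl.load(src + offs, mask=mask)  # materialize the RHS tensor first
    tl.store(dst + row * N + offs, value, mask=mask)  # then store it
```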
```diff
@@ -503,7 +594,9 @@ def compute_shape(
     ) -> list[int | torch.SymInt]:
         assert isinstance(tensor, torch.Tensor)
         assert isinstance(index, (list, tuple)), index
-        input_size = collections.deque(tensor.size())
+        input_size: collections.deque[int | torch.SymInt] = collections.deque(
+            tensor.size()
+        )
         output_size = []
         env = CompileEnvironment.current()
         for i, k in enumerate(index):
@@ -547,7 +640,8 @@ def compute_shape(
                 output_size.extend(k.size())
             else:
                 raise exc.InvalidIndexingType(k)
-        assert len(input_size) == 0, "invalid subscript"
+        # For partial indexing, append remaining dimensions to output
+        _append_remaining_dimensions(input_size, output_size, env)
         return output_size
 
     @staticmethod
```
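Concretely, `compute_shape` keeps a deque of input dimensions that each subscript element consumes; the change replaces the old `assert len(input_size) == 0` with a fall-through that appends whatever remains. A simplified standalone model (index kinds reduced to ints and full slices, reduction-dim allocation replaced by passing sizes through):

```python
import collections

def toy_compute_shape(tensor_size, index):
    input_size = collections.deque(tensor_size)
    output_size = []
    for k in index:
        if isinstance(k, int):
            input_size.popleft()  # consumed and dropped from the output
        elif k == slice(None):
            output_size.append(input_size.popleft())  # consumed and kept
        else:
            raise TypeError(k)
    # New behavior: unconsumed trailing dims flow through instead of asserting.
    output_size.extend(input_size)
    return output_size

print(toy_compute_shape((4, 8, 16), [0]))               # [8, 16]
print(toy_compute_shape((4, 8, 16), [slice(None), 2]))  # [4, 16]
```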
```diff
@@ -675,6 +769,20 @@ def create(
                 )
             else:
                 raise exc.InvalidIndexingType(type(k))
+
+        # Handle remaining dimensions for partial indexing
+        output_idx = _handle_remaining_index_dimensions(
+            index_values,
+            mask_values,
+            output_size,
+            output_idx,
+            fake_value,
+            state,
+            tile_strategy,
+            env,
+            dtype,
+        )
+
         assert len(output_size) == output_idx
         assert len(index_values) == fake_value.ndim
         index_expr = []
@@ -800,7 +908,9 @@ def is_supported(
         if extra_mask is not None:
             # TODO(jansel): support block_ptr with extra_mask
             return False
-        input_sizes = collections.deque(fake_tensor.size())
+        input_sizes: collections.deque[int | torch.SymInt] = collections.deque(
+            fake_tensor.size()
+        )
         for n, k in enumerate(index):
             if k is None:
                 input_size = 1
```

helion/_compiler/type_propagation.py

Lines changed: 17 additions & 5 deletions
```diff
@@ -2,6 +2,7 @@
 
 import ast
 import builtins
+import collections
 import contextlib
 import dataclasses
 import functools
@@ -485,12 +486,23 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
                 raise exc.OverpackedTile(k)
             else:
                 raise exc.InvalidIndexingType(k)
-        if inputs_consumed != self.fake_value.ndim:
-            raise exc.RankMismatch(
-                self.fake_value.ndim,
-                inputs_consumed,
-                f"tensor shape: {tuple(self.fake_value.shape)}",
+        # Handle partial indexing - add remaining dimensions to output
+        if inputs_consumed < self.fake_value.ndim:
+            # Create a deque with remaining dimensions
+            remaining_sizes: collections.deque[int | torch.SymInt] = collections.deque(
+                self.fake_value.size(i)
+                for i in range(inputs_consumed, self.fake_value.ndim)
             )
+            if self.origin.is_device():
+                # On device, just append the sizes directly
+                output_sizes.extend(remaining_sizes)
+            else:
+                # On host, use the helper to allocate reduction dimensions
+                from helion._compiler.indexing_strategy import (
+                    _append_remaining_dimensions,
+                )
+
+                _append_remaining_dimensions(remaining_sizes, output_sizes, env)
         return output_sizes
 
     def propagate_setitem(
```
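In user terms, the relaxed check admits kernels that index a 2D tensor with a single index and assign a whole tensor to the resulting slice — the pattern `test_tensor_value` below exercises. A hypothetical sketch (kernel and names invented, assuming the usual `helion`/`hl.grid` entry points; not code from this repo):

```python
import torch
import helion
import helion.language as hl

@helion.kernel
def broadcast_row(x: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # x: [m, n], v: [n]
    out = torch.empty_like(x)
    for i in hl.grid(x.size(0)):
        # Setter from a tensor value into a partially indexed row; this
        # previously failed type propagation with RankMismatch
        # (expected ndim=2, got ndim=1).
        out[i] = v
    return out
```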

helion/language/_tracing_ops.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -68,7 +68,11 @@ def _host_tensor(debug_name: str) -> torch.Tensor:
 
 @_decorators.codegen(_host_tensor)
 def _(state: CodegenState) -> ast.AST:
-    return expr_from_string("_host_tensor")  # should be unused
+    # Get the tensor from the FX node metadata
+    tensor = state.fx_node.meta["val"]  # pyright: ignore[reportOptionalMemberAccess]
+    # Get the tensor argument name from the device function
+    tensor_arg = state.device_function.tensor_arg(tensor)
+    return expr_from_string(tensor_arg.name)
 
 
 @has_side_effect
```
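Design-wise, the old placeholder assumed `_host_tensor` expressions never reached generated source; tensor-value stores like the `codegen_store` path above break that assumption, so the codegen now resolves the real kernel parameter name. An invented before/after of the emitted expression, assuming a host tensor bound to the parameter `src`:

```python
# Before: any reference to the host tensor emitted a dummy identifier,
# which is undefined in the generated Triton kernel:
#     _host_tensor
# After: the FX node's fake tensor is looked up via
# device_function.tensor_arg(...), so the emitted expression is the
# actual kernel parameter, e.g.:
#     src
```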

test/test_indexing.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -808,9 +808,6 @@ def kernel(
         torch.testing.assert_close(src_result, expected_src)
         torch.testing.assert_close(dst_result, expected_dst)
 
-    @skipIfNormalMode(
-        "RankMismatch: Expected ndim=2, but got ndim=1 - tensor value assignment shape mismatch"
-    )
     def test_tensor_value(self):
         """Test both setter from tensor value and getter for [i]"""
```
