
Commit 128393a

Fix tensor value assignment and partial indexing in Helion
stack-info: PR: #439, branch: yf225/stack/55
1 parent 07cc5e5 commit 128393a

File tree: 4 files changed (+59, -11 lines)

helion/_compiler/indexing_strategy.py (43 additions, 1 deletion)
@@ -102,6 +102,25 @@ def codegen_store(
     ) -> ast.AST:
         indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask)
         name = state.device_function.tensor_arg(fake_tensor).name
+
+        # Check if value is a tensor load (Name node with id matching a tensor arg)
+        if isinstance(value, ast.Name) and hasattr(state.device_function, '_tensor_args'):
+            # Check if this name corresponds to a tensor argument
+            for tensor, tensor_arg in state.device_function._tensor_args.items():
+                if tensor_arg.name == value.id:
+                    # This is a tensor value, we need to load from it
+                    # Get the shape of the slice we're storing to
+                    output_shape = SubscriptIndexing.compute_shape(fake_tensor, subscript)
+                    if len(output_shape) == 1 and tensor.ndim == 1:
+                        # Load the entire 1D tensor
+                        value_indexing = SubscriptIndexing.create(state, tensor, [slice(None)], None)
+                        value = expr_from_string(
+                            f"tl.load({value.id} + offset, mask)",
+                            offset=value_indexing.index_expr,
+                            mask=value_indexing.mask_expr,
+                        )
+                    break
+
         return expr_from_string(
             f"tl.store({name} + offset, value, mask)",
             value=value,
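For orientation, here is a minimal Helion kernel that would exercise this new store path. This is a hedged sketch, not code from this commit; it assumes `dst` is 2D, `src` is 1D with matching row width, and that `hl.grid` iterates rows:

import torch
import helion
import helion.language as hl


@helion.kernel
def fill_rows(dst: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    # dst: (m, n), src: (n,). Assigning the whole 1D tensor `src` to the
    # partially indexed row dst[i] is what the new codegen_store branch
    # handles: a bare tensor-argument name on the RHS is rewritten into an
    # explicit tl.load before the tl.store is emitted.
    for i in hl.grid(dst.size(0)):
        dst[i] = src
    return dst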
@@ -511,7 +530,14 @@ def compute_shape(
                 output_size.extend(k.size())
             else:
                 raise exc.InvalidIndexingType(k)
-        assert len(input_size) == 0, "invalid subscript"
+        # For partial indexing, append remaining dimensions to output
+        while input_size:
+            size = input_size.popleft()
+            if size != 1:
+                rdim = env.allocate_reduction_dimension(size)
+                output_size.append(rdim.var)
+            else:
+                output_size.append(1)
         return output_size
 
     @staticmethod
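To make the shape rule concrete, here is a standalone mirror of the new tail of `compute_shape`. This is an illustrative sketch: `allocate_rdim` stands in for `env.allocate_reduction_dimension(size).var`, and the names are hypothetical:

from collections import deque

def append_remaining_dims(output_size: list, input_size: deque, allocate_rdim) -> list:
    # Any dimensions the subscript did not consume are appended to the
    # output shape: size-1 dims stay a literal 1, larger dims become a
    # reduction-dimension symbol.
    while input_size:
        size = input_size.popleft()
        output_size.append(allocate_rdim(size) if size != 1 else 1)
    return output_size

# e.g. x[i] on an (m, n) tensor consumes only dim 0, so n is left over:
# append_remaining_dims([tile_m], deque([n]), allocate_rdim) -> [tile_m, rdim_n]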
@@ -648,6 +674,22 @@ def create(
                 )
             else:
                 raise exc.InvalidIndexingType(type(k))
+
+        # Handle remaining dimensions for partial indexing
+        while len(index_values) < fake_value.ndim:
+            expand = tile_strategy.expand_str(output_size, output_idx)
+            size = fake_value.size(len(index_values))
+            if size != 1:
+                rdim = env.allocate_reduction_dimension(size)
+                block_idx = rdim.block_id
+                index_var = state.codegen.index_var(block_idx)
+                index_values.append(f"({index_var}){expand}")
+                if mask := state.codegen.mask_var(block_idx):
+                    mask_values.setdefault(f"({mask}){expand}")
+            else:
+                index_values.append(f"tl.zeros([1], {dtype}){expand}")
+            output_idx += 1
+
         assert len(output_size) == output_idx
         assert len(index_values) == fake_value.ndim
         index_expr = []
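The `{expand}` suffix broadcasts each per-dimension index vector against the full output shape. Below is a simplified stand-in for `tile_strategy.expand_str` with assumed semantics, not the real implementation (the real helper takes the output shape list rather than an ndim):

def expand_str(output_ndim: int, axis: int) -> str:
    # Build a subscript that keeps `axis` and inserts None elsewhere, so a
    # 1D index vector broadcasts along the other output dimensions.
    subs = [":" if i == axis else "None" for i in range(output_ndim)]
    return "[" + ", ".join(subs) + "]"

assert expand_str(2, 0) == "[:, None]"
assert expand_str(2, 1) == "[None, :]"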

helion/_compiler/type_propagation.py (11 additions, 6 deletions)
@@ -487,12 +487,17 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
                 raise exc.OverpackedTile(k)
             else:
                 raise exc.InvalidIndexingType(k)
-        if inputs_consumed != self.fake_value.ndim:
-            raise exc.RankMismatch(
-                self.fake_value.ndim,
-                inputs_consumed,
-                f"tensor shape: {tuple(self.fake_value.shape)}",
-            )
+        # Handle partial indexing - add remaining dimensions to output
+        if inputs_consumed < self.fake_value.ndim:
+            for i in range(inputs_consumed, self.fake_value.ndim):
+                size = self.fake_value.size(i)
+                if self.origin.is_device():
+                    output_sizes.append(size)
+                elif size != 1:
+                    rdim = env.allocate_reduction_dimension(size)
+                    output_sizes.append(rdim.var)
+                else:
+                    output_sizes.append(1)
         return output_sizes
 
     def propagate_setitem(
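Behaviorally, a subscript that consumes fewer dimensions than the tensor has no longer raises `RankMismatch`; the leftover dimensions are carried into the propagated type. A hedged sketch of code that should now type-check (assuming a `.sum()` reduction over the row is otherwise supported):

import torch
import helion
import helion.language as hl


@helion.kernel
def row_sums(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty(x.size(0), dtype=x.dtype, device=x.device)
    for i in hl.grid(x.size(0)):
        # x[i] consumes one of x's two dims; the trailing dim becomes a
        # reduction dimension in the propagated type instead of an error.
        out[i] = x[i].sum()
    return out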

helion/language/_tracing_ops.py (5 additions, 1 deletion)
@@ -68,7 +68,11 @@ def _host_tensor(debug_name: str) -> torch.Tensor:
 
 @_decorators.codegen(_host_tensor)
 def _(state: CodegenState) -> ast.AST:
-    return expr_from_string("_host_tensor")  # should be unused
+    # Get the tensor from the FX node metadata
+    tensor = state.fx_node.meta["val"]  # pyright: ignore[reportOptionalMemberAccess]
+    # Get the tensor argument name from the device function
+    tensor_arg = state.device_function.tensor_arg(tensor)
+    return expr_from_string(tensor_arg.name)
 
 
 @has_side_effect
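This matters because the tensor-value store path added above reads `value.id` off the emitted AST: a `_host_tensor` node must now resolve to the real kernel-argument name rather than a dead placeholder. A simplified stand-in for `expr_from_string` (assumed behavior; the real helper also splices placeholder sub-expressions passed as keywords):

import ast

def expr_from_string_sketch(source: str) -> ast.AST:
    # Parse a Python expression string into its AST node.
    return ast.parse(source, mode="eval").body

# Old codegen: expr_from_string_sketch("_host_tensor")  -> undefined name if used
# New codegen: expr_from_string_sketch("src")           -> the real tensor argument
# (the name "src" is illustrative)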

test/test_indexing.py (0 additions, 3 deletions)
@@ -808,9 +808,6 @@ def kernel(
         torch.testing.assert_close(src_result, expected_src)
         torch.testing.assert_close(dst_result, expected_dst)
 
-    @skipIfNormalMode(
-        "RankMismatch: Expected ndim=2, but got ndim=1 - tensor value assignment shape mismatch"
-    )
     def test_tensor_value(self):
         """Test both setter from tensor value and getter for [i]"""

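With both fixes in place, the `skipIfNormalMode` guard is no longer needed. A hypothetical reconstruction of the pattern the test exercises (the actual test body is not shown in this diff; all names are illustrative):

import torch
import helion
import helion.language as hl


@helion.kernel
def kernel(x: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    dst = torch.empty_like(x)
    sums = torch.empty(x.size(0), dtype=x.dtype, device=x.device)
    for i in hl.grid(x.size(0)):
        dst[i] = v            # setter from a 1D tensor value
        sums[i] = x[i].sum()  # getter for [i] (partial indexing)
    return dst, sums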