Skip to content

Commit 2ea0b32

Browse files

Committed: Fix static slice indexing with explicit start/stop bounds
stack-info: PR: #440, branch: yf225/stack/56
1 parent f9b122a commit 2ea0b32

File tree

4 files changed

+62
-17
lines changed

4 files changed

+62
-17
lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .host_function import HostFunction
1919
from .tile_strategy import DeviceLoopState
2020
from .utils import compute_slice_size
21+
from .utils import get_slice_start
2122
from .variable_origin import BlockSizeOrigin
2223

2324
if TYPE_CHECKING:
@@ -126,6 +127,30 @@ def _handle_remaining_index_dimensions(
126127
return output_idx
127128

128129

130+
def _generate_slice_index(
131+
start: int | torch.SymInt,
132+
index_var: str,
133+
expand: str,
134+
step: int | None = None,
135+
) -> str:
136+
"""Generate slice index expression with optional step."""
137+
if step is not None:
138+
# Strided index: start + index * step
139+
return f"({start} + ({index_var}) * {step}){expand}"
140+
if start != 0:
141+
# Index with offset: start + index
142+
return f"({start} + ({index_var})){expand}"
143+
# Simple index
144+
return f"({index_var}){expand}"
145+
146+
147+
def _generate_offset_expr(start: int | torch.SymInt, offset: str) -> str:
148+
"""Generate offset expression with optional start."""
149+
if start != 0:
150+
return f"({start} + {offset})"
151+
return offset
152+
153+
129154
class IndexingStrategy:
130155
def codegen_load(
131156
self,
@@ -627,7 +652,6 @@ def compute_shape(
627652
size = input_size.popleft()
628653
# Handle slices with steps
629654
slice_size = compute_slice_size(k, size)
630-
631655
if slice_size != 1:
632656
rdim = env.allocate_reduction_dimension(slice_size)
633657
output_size.append(rdim.var)
@@ -719,25 +743,29 @@ def create(
719743
rdim = env.allocate_reduction_dimension(slice_size)
720744
block_idx = rdim.block_id
721745
index_var = state.codegen.index_var(block_idx)
722-
# Generate strided index: start + index * step
723746
index_values.append(
724-
f"({start} + ({index_var}) * {step}){expand}"
747+
_generate_slice_index(start, index_var, expand, step)
725748
)
726749
if mask := state.codegen.mask_var(block_idx):
727750
mask_values.setdefault(f"({mask}){expand}")
728751
else:
729752
index_values.append(f"{start}{expand}")
730753
else:
731-
# Full slice or slice without step
732-
if size != 1:
733-
rdim = env.allocate_reduction_dimension(size)
754+
# Handle slices with start/stop but no step
755+
start = get_slice_start(k)
756+
slice_size = compute_slice_size(k, size)
757+
758+
if slice_size != 1:
759+
rdim = env.allocate_reduction_dimension(slice_size)
734760
block_idx = rdim.block_id
735761
index_var = state.codegen.index_var(block_idx)
736-
index_values.append(f"({index_var}){expand}")
762+
index_values.append(
763+
_generate_slice_index(start, index_var, expand)
764+
)
737765
if mask := state.codegen.mask_var(block_idx):
738766
mask_values.setdefault(f"({mask}){expand}")
739767
else:
740-
index_values.append(f"tl.zeros([1], {dtype}){expand}")
768+
index_values.append(f"{start}{expand}")
741769
output_idx += 1
742770
elif isinstance(k, torch.Tensor) and k.ndim == 1:
743771
expand = tile_strategy.expand_str(output_size, output_idx)
@@ -1025,8 +1053,19 @@ def create(
10251053
res.offsets.append(state.codegen.offset_var(rdim.block_id))
10261054
res.block_shape.append(rdim.var)
10271055
else:
1028-
res.offsets.append("0")
1029-
res.block_shape.append(1)
1056+
# Handle slices with start/stop but no step
1057+
start = get_slice_start(k)
1058+
slice_size = compute_slice_size(k, size)
1059+
1060+
if slice_size != 1:
1061+
env = CompileEnvironment.current()
1062+
rdim = env.allocate_reduction_dimension(slice_size)
1063+
offset = state.codegen.offset_var(rdim.block_id)
1064+
res.offsets.append(_generate_offset_expr(start, offset))
1065+
res.block_shape.append(rdim.var)
1066+
else:
1067+
res.offsets.append(str(start))
1068+
res.block_shape.append(1)
10301069
else:
10311070
raise exc.InvalidIndexingType(k)
10321071
res.validate()

helion/_compiler/type_propagation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,6 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
466466

467467
# For slices with steps, we need to calculate the output size differently
468468
output_size = compute_slice_size(slice_obj, size)
469-
470469
if self.origin.is_device():
471470
output_sizes.append(output_size)
472471
elif output_size != 1:
@@ -515,8 +514,9 @@ def propagate_setitem(
515514
lhs_rank = len(lhs_shape)
516515
if isinstance(value, TensorType):
517516
rhs_rank = value.fake_value.ndim
518-
# Allow scalar tensors (rank 0) to be assigned to any rank (broadcasts)
519-
if rhs_rank != 0 and lhs_rank != rhs_rank:
517+
rhs_numel = value.fake_value.numel()
518+
# Allow scalar tensors (rank 0) or single-element tensors to be assigned to any rank (broadcasts)
519+
if rhs_rank != 0 and rhs_numel != 1 and lhs_rank != rhs_rank:
520520
raise exc.RankMismatch(
521521
lhs_rank,
522522
rhs_rank,

helion/_compiler/utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,12 @@ def compute_slice_size(
2525
stop = slice_obj.stop if slice_obj.stop is not None else original_size
2626
step = slice_obj.step
2727
return (stop - start + step - 1) // step
28-
# Full slice or slice without step
29-
return original_size
28+
# Calculate slice size based on start/stop
29+
start = slice_obj.start if slice_obj.start is not None else 0
30+
stop = slice_obj.stop if slice_obj.stop is not None else original_size
31+
return stop - start
32+
33+
34+
def get_slice_start(slice_obj: slice) -> int:
35+
"""Get the start index of a slice, defaulting to 0."""
36+
return slice_obj.start if slice_obj.start is not None else 0

test/test_indexing.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,6 @@ def kernel(
877877
torch.testing.assert_close(src_result, expected_src)
878878
torch.testing.assert_close(dst_result, expected_dst)
879879

880-
@skipIfNormalMode("InternalError: Unexpected type <class 'slice'>")
881880
def test_range_slice(self):
882881
"""Test both setter from scalar and getter for [10:20]"""
883882

@@ -904,7 +903,7 @@ def kernel(
904903
torch.testing.assert_close(dst_result, expected_dst)
905904

906905
@skipIfNormalMode(
907-
"InternalError: AssertionError in type_propagation.py - slice indexing error"
906+
"Dynamic slices (i:i+1) are not supported - FX cannot trace symbolic slice indices"
908907
)
909908
def test_range_slice_dynamic(self):
910909
"""Test both [i:i+1] = scalar and [i] = [i:i+1] patterns"""

0 commit comments

Comments (0)