Skip to content

Commit 1463e9a

Browse files
committed
Fix static slice indexing with explicit start/stop bounds
stack-info: PR: #440, branch: yf225/stack/56
1 parent 630b8ef commit 1463e9a

File tree

4 files changed

+62
-17
lines changed

4 files changed

+62
-17
lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .host_function import HostFunction
1919
from .tile_strategy import DeviceLoopState
2020
from .utils import compute_slice_size
21+
from .utils import get_slice_start
2122
from .variable_origin import BlockSizeOrigin
2223

2324
if TYPE_CHECKING:
@@ -126,6 +127,30 @@ def _handle_remaining_index_dimensions(
126127
return output_idx
127128

128129

130+
def _generate_slice_index(
131+
start: int | torch.SymInt,
132+
index_var: str,
133+
expand: str,
134+
step: int | None = None,
135+
) -> str:
136+
"""Generate slice index expression with optional step."""
137+
if step is not None:
138+
# Strided index: start + index * step
139+
return f"({start} + ({index_var}) * {step}){expand}"
140+
if start != 0:
141+
# Index with offset: start + index
142+
return f"({start} + ({index_var})){expand}"
143+
# Simple index
144+
return f"({index_var}){expand}"
145+
146+
147+
def _generate_offset_expr(start: int | torch.SymInt, offset: str) -> str:
148+
"""Generate offset expression with optional start."""
149+
if start != 0:
150+
return f"({start} + {offset})"
151+
return offset
152+
153+
129154
class IndexingStrategy:
130155
def codegen_load(
131156
self,
@@ -628,7 +653,6 @@ def compute_shape(
628653
size = input_size.popleft()
629654
# Handle slices with steps
630655
slice_size = compute_slice_size(k, size)
631-
632656
if slice_size != 1:
633657
rdim = env.allocate_reduction_dimension(slice_size)
634658
output_size.append(rdim.var)
@@ -721,25 +745,29 @@ def create(
721745
rdim = env.allocate_reduction_dimension(slice_size)
722746
block_idx = rdim.block_id
723747
index_var = state.codegen.index_var(block_idx)
724-
# Generate strided index: start + index * step
725748
index_values.append(
726-
f"({start} + ({index_var}) * {step}){expand}"
749+
_generate_slice_index(start, index_var, expand, step)
727750
)
728751
if mask := state.codegen.mask_var(block_idx):
729752
mask_values.setdefault(f"({mask}){expand}")
730753
else:
731754
index_values.append(f"{start}{expand}")
732755
else:
733-
# Full slice or slice without step
734-
if size != 1:
735-
rdim = env.allocate_reduction_dimension(size)
756+
# Handle slices with start/stop but no step
757+
start = get_slice_start(k)
758+
slice_size = compute_slice_size(k, size)
759+
760+
if slice_size != 1:
761+
rdim = env.allocate_reduction_dimension(slice_size)
736762
block_idx = rdim.block_id
737763
index_var = state.codegen.index_var(block_idx)
738-
index_values.append(f"({index_var}){expand}")
764+
index_values.append(
765+
_generate_slice_index(start, index_var, expand)
766+
)
739767
if mask := state.codegen.mask_var(block_idx):
740768
mask_values.setdefault(f"({mask}){expand}")
741769
else:
742-
index_values.append(f"tl.zeros([1], {dtype}){expand}")
770+
index_values.append(f"{start}{expand}")
743771
output_idx += 1
744772
elif isinstance(k, torch.Tensor) and k.ndim == 1:
745773
expand = tile_strategy.expand_str(output_size, output_idx)
@@ -1029,8 +1057,19 @@ def create(
10291057
res.offsets.append(state.codegen.offset_var(rdim.block_id))
10301058
res.block_shape.append(rdim.var)
10311059
else:
1032-
res.offsets.append("0")
1033-
res.block_shape.append(1)
1060+
# Handle slices with start/stop but no step
1061+
start = get_slice_start(k)
1062+
slice_size = compute_slice_size(k, size)
1063+
1064+
if slice_size != 1:
1065+
env = CompileEnvironment.current()
1066+
rdim = env.allocate_reduction_dimension(slice_size)
1067+
offset = state.codegen.offset_var(rdim.block_id)
1068+
res.offsets.append(_generate_offset_expr(start, offset))
1069+
res.block_shape.append(rdim.var)
1070+
else:
1071+
res.offsets.append(str(start))
1072+
res.block_shape.append(1)
10341073
else:
10351074
raise exc.InvalidIndexingType(k)
10361075
res.validate()

helion/_compiler/type_propagation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,6 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
468468

469469
# For slices with steps, we need to calculate the output size differently
470470
output_size = compute_slice_size(slice_obj, size)
471-
472471
if self.origin.is_device():
473472
output_sizes.append(output_size)
474473
elif output_size != 1:
@@ -517,8 +516,9 @@ def propagate_setitem(
517516
lhs_rank = len(lhs_shape)
518517
if isinstance(value, TensorType):
519518
rhs_rank = value.fake_value.ndim
520-
# Allow scalar tensors (rank 0) to be assigned to any rank (broadcasts)
521-
if rhs_rank != 0 and lhs_rank != rhs_rank:
519+
rhs_numel = value.fake_value.numel()
520+
# Allow scalar tensors (rank 0) or single-element tensors to be assigned to any rank (broadcasts)
521+
if rhs_rank != 0 and rhs_numel != 1 and lhs_rank != rhs_rank:
522522
raise exc.RankMismatch(
523523
lhs_rank,
524524
rhs_rank,

helion/_compiler/utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,12 @@ def compute_slice_size(
2525
stop = slice_obj.stop if slice_obj.stop is not None else original_size
2626
step = slice_obj.step
2727
return (stop - start + step - 1) // step
28-
# Full slice or slice without step
29-
return original_size
28+
# Calculate slice size based on start/stop
29+
start = slice_obj.start if slice_obj.start is not None else 0
30+
stop = slice_obj.stop if slice_obj.stop is not None else original_size
31+
return stop - start
32+
33+
34+
def get_slice_start(slice_obj: slice) -> int:
35+
"""Get the start index of a slice, defaulting to 0."""
36+
return slice_obj.start if slice_obj.start is not None else 0

test/test_indexing.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,6 @@ def kernel(
877877
torch.testing.assert_close(src_result, expected_src)
878878
torch.testing.assert_close(dst_result, expected_dst)
879879

880-
@skipIfNormalMode("InternalError: Unexpected type <class 'slice'>")
881880
def test_range_slice(self):
882881
"""Test both setter from scalar and getter for [10:20]"""
883882

@@ -904,7 +903,7 @@ def kernel(
904903
torch.testing.assert_close(dst_result, expected_dst)
905904

906905
@skipIfNormalMode(
907-
"InternalError: AssertionError in type_propagation.py - slice indexing error"
906+
"Dynamic slices (i:i+1) are not supported - FX cannot trace symbolic slice indices"
908907
)
909908
def test_range_slice_dynamic(self):
910909
"""Test both [i:i+1] = scalar and [i] = [i:i+1] patterns"""

0 commit comments

Comments (0)