Fix static slice indexing with explicit start/stop bounds #440

Status: Open. Wants to merge 1 commit into base: yf225/stack/55.

This PR makes static slices with explicit start/stop bounds (e.g. x[10:20]) work in device code: the slice size is computed as stop - start, and the generated indices and block-pointer offsets are shifted by the slice start.

59 changes: 49 additions & 10 deletions helion/_compiler/indexing_strategy.py
@@ -18,6 +18,7 @@
 from .host_function import HostFunction
 from .tile_strategy import DeviceLoopState
 from .utils import compute_slice_size
+from .utils import get_slice_start
 from .variable_origin import BlockSizeOrigin
 
 if TYPE_CHECKING:
@@ -126,6 +127,30 @@ def _handle_remaining_index_dimensions(
     return output_idx
 
 
+def _generate_slice_index(
+    start: int | torch.SymInt,
+    index_var: str,
+    expand: str,
+    step: int | None = None,
+) -> str:
+    """Generate slice index expression with optional step."""
+    if step is not None:
+        # Strided index: start + index * step
+        return f"({start} + ({index_var}) * {step}){expand}"
+    if start != 0:
+        # Index with offset: start + index
+        return f"({start} + ({index_var})){expand}"
+    # Simple index
+    return f"({index_var}){expand}"
+
+
+def _generate_offset_expr(start: int | torch.SymInt, offset: str) -> str:
+    """Generate offset expression with optional start."""
+    if start != 0:
+        return f"({start} + {offset})"
+    return offset
+
+
 class IndexingStrategy:
     def codegen_load(
         self,
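
For illustration, here is what these helpers emit for a few representative inputs (a sketch; the index_var, expand, and offset strings below are hypothetical placeholders for codegen-supplied values):

    _generate_slice_index(0, "indices_0", "[None, :]")           # -> "(indices_0)[None, :]"
    _generate_slice_index(10, "indices_0", "[None, :]")          # -> "(10 + (indices_0))[None, :]"
    _generate_slice_index(2, "indices_0", "[None, :]", step=3)   # -> "(2 + (indices_0) * 3)[None, :]"
    _generate_offset_expr(10, "offset_0")                        # -> "(10 + offset_0)"
    _generate_offset_expr(0, "offset_0")                         # -> "offset_0"
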
@@ -627,7 +652,6 @@ def compute_shape(
             size = input_size.popleft()
             # Handle slices with steps
             slice_size = compute_slice_size(k, size)
-
             if slice_size != 1:
                 rdim = env.allocate_reduction_dimension(slice_size)
                 output_size.append(rdim.var)
@@ -719,25 +743,29 @@ def create(
                     rdim = env.allocate_reduction_dimension(slice_size)
                     block_idx = rdim.block_id
                     index_var = state.codegen.index_var(block_idx)
-                    # Generate strided index: start + index * step
                     index_values.append(
-                        f"({start} + ({index_var}) * {step}){expand}"
+                        _generate_slice_index(start, index_var, expand, step)
                     )
                     if mask := state.codegen.mask_var(block_idx):
                         mask_values.setdefault(f"({mask}){expand}")
                 else:
                     index_values.append(f"{start}{expand}")
             else:
-                # Full slice or slice without step
-                if size != 1:
-                    rdim = env.allocate_reduction_dimension(size)
+                # Handle slices with start/stop but no step
+                start = get_slice_start(k)
+                slice_size = compute_slice_size(k, size)
+
+                if slice_size != 1:
+                    rdim = env.allocate_reduction_dimension(slice_size)
                     block_idx = rdim.block_id
                     index_var = state.codegen.index_var(block_idx)
-                    index_values.append(f"({index_var}){expand}")
+                    index_values.append(
+                        _generate_slice_index(start, index_var, expand)
+                    )
                     if mask := state.codegen.mask_var(block_idx):
                         mask_values.setdefault(f"({mask}){expand}")
                 else:
-                    index_values.append(f"tl.zeros([1], {dtype}){expand}")
+                    index_values.append(f"{start}{expand}")
             output_idx += 1
         elif isinstance(k, torch.Tensor) and k.ndim == 1:
             expand = tile_strategy.expand_str(output_size, output_idx)
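
Concretely, for a load like x[10:20] on this (non-block-pointer) path, a reduction dimension of size 10 is allocated and the start offset folds into the index expression. A hypothetical sketch of the resulting generated Triton code (variable names vary per kernel; mask_0 is a placeholder):

    indices_0 = tl.arange(0, 10)  # rdim sized by compute_slice_size, not the full dimension
    value = tl.load(x + (10 + (indices_0)) * x_stride_0, mask=mask_0, other=0)  # start-shifted index
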
@@ -1025,8 +1053,19 @@ def create(
                     res.offsets.append(state.codegen.offset_var(rdim.block_id))
                     res.block_shape.append(rdim.var)
                 else:
-                    res.offsets.append("0")
-                    res.block_shape.append(1)
+                    # Handle slices with start/stop but no step
+                    start = get_slice_start(k)
+                    slice_size = compute_slice_size(k, size)
+
+                    if slice_size != 1:
+                        env = CompileEnvironment.current()
+                        rdim = env.allocate_reduction_dimension(slice_size)
+                        offset = state.codegen.offset_var(rdim.block_id)
+                        res.offsets.append(_generate_offset_expr(start, offset))
+                        res.block_shape.append(rdim.var)
+                    else:
+                        res.offsets.append(str(start))
+                        res.block_shape.append(1)
             else:
                 raise exc.InvalidIndexingType(k)
             res.validate()
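
On the block-pointer path, the same slice now contributes a start-adjusted offset instead of the literal "0". A hypothetical sketch of the resulting tl.make_block_ptr call for x[10:20] on a 1-D tensor of size 100 (offset_0 is a placeholder for the codegen offset variable):

    block_ptr = tl.make_block_ptr(x, shape=[100], strides=[1],
                                  offsets=[10 + offset_0], block_shape=[10], order=[0])
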
6 changes: 3 additions & 3 deletions helion/_compiler/type_propagation.py
@@ -466,7 +466,6 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:

         # For slices with steps, we need to calculate the output size differently
         output_size = compute_slice_size(slice_obj, size)
-
         if self.origin.is_device():
             output_sizes.append(output_size)
         elif output_size != 1:
@@ -515,8 +514,9 @@ def propagate_setitem(
         lhs_rank = len(lhs_shape)
         if isinstance(value, TensorType):
             rhs_rank = value.fake_value.ndim
-            # Allow scalar tensors (rank 0) to be assigned to any rank (broadcasts)
-            if rhs_rank != 0 and lhs_rank != rhs_rank:
+            rhs_numel = value.fake_value.numel()
+            # Allow scalar tensors (rank 0) or single-element tensors to be assigned to any rank (broadcasts)
+            if rhs_rank != 0 and rhs_numel != 1 and lhs_rank != rhs_rank:
                 raise exc.RankMismatch(
                     lhs_rank,
                     rhs_rank,
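
The added rhs_numel check relaxes the rank test so that any single-element tensor, not just a rank-0 scalar, can broadcast into a destination of different rank. In plain PyTorch terms (a sketch of what now passes type propagation):

    lhs = torch.zeros(4, 8)
    value = torch.tensor([5.0])   # ndim == 1 but numel() == 1
    lhs[0:2] = value              # broadcasts like a scalar; no RankMismatch
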
11 changes: 9 additions & 2 deletions helion/_compiler/utils.py
@@ -25,5 +25,12 @@ def compute_slice_size(
         stop = slice_obj.stop if slice_obj.stop is not None else original_size
         step = slice_obj.step
         return (stop - start + step - 1) // step
-    # Full slice or slice without step
-    return original_size
+    # Calculate slice size based on start/stop
+    start = slice_obj.start if slice_obj.start is not None else 0
+    stop = slice_obj.stop if slice_obj.stop is not None else original_size
+    return stop - start
+
+
+def get_slice_start(slice_obj: slice) -> int:
+    """Get the start index of a slice, defaulting to 0."""
+    return slice_obj.start if slice_obj.start is not None else 0
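
Worked examples of the two helpers (doctest-style; values chosen for illustration):

    >>> compute_slice_size(slice(10, 20), 100)    # stop - start
    10
    >>> compute_slice_size(slice(None), 100)      # defaults: start=0, stop=100
    100
    >>> compute_slice_size(slice(2, 11, 3), 100)  # stepped: (11 - 2 + 3 - 1) // 3
    3
    >>> get_slice_start(slice(10, 20))
    10
    >>> get_slice_start(slice(None))
    0
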
3 changes: 1 addition & 2 deletions test/test_indexing.py
@@ -877,7 +877,6 @@ def kernel(
         torch.testing.assert_close(src_result, expected_src)
         torch.testing.assert_close(dst_result, expected_dst)
 
-    @skipIfNormalMode("InternalError: Unexpected type <class 'slice'>")
     def test_range_slice(self):
         """Test both setter from scalar and getter for [10:20]"""
 
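
In plain PyTorch terms, the pattern this now-unskipped test exercises looks like the following (a sketch, not the test's actual body):

    src = torch.arange(100.0)
    dst = torch.zeros(100)
    dst[10:20] = 5.0   # setter: scalar into a bounded slice
    out = src[10:20]   # getter: 10 elements starting at offset 10
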
@@ -904,7 +903,7 @@ def kernel(
         torch.testing.assert_close(dst_result, expected_dst)
 
     @skipIfNormalMode(
-        "InternalError: AssertionError in type_propagation.py - slice indexing error"
+        "Dynamic slices (i:i+1) are not supported - FX cannot trace symbolic slice indices"
     )
     def test_range_slice_dynamic(self):
         """Test both [i:i+1] = scalar and [i] = [i:i+1] patterns"""