Skip to content

Commit ef712a0

Browse files
committed
Add support for ellipsis (...) indexing in Helion
stack-info: PR: #437, branch: yf225/stack/53
1 parent 4718678 commit ef712a0

File tree

3 files changed

+85
-7
lines changed

3 files changed

+85
-7
lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ def valid_block_size(
227227
for i, k in enumerate(subscript):
228228
if k is None:
229229
continue
230+
if k is Ellipsis:
231+
# Ellipsis is not supported in tensor descriptor mode
232+
return False
230233
size, stride = size_stride.popleft()
231234
if isinstance(k, slice):
232235
# Slices with steps are not supported in tensor descriptor mode
@@ -447,6 +450,14 @@ def codegen_store(
447450
)
448451

449452

453+
def _calculate_ellipsis_dims(
454+
index: list[object], current_index: int, total_dims: int
455+
) -> int:
456+
"""Calculate how many dimensions an ellipsis should expand to."""
457+
remaining_indices = len(index) - current_index - 1
458+
return total_dims - current_index - remaining_indices
459+
460+
450461
class SubscriptIndexing(NamedTuple):
451462
index_expr: ast.AST
452463
mask_expr: ast.AST
@@ -465,9 +476,19 @@ def compute_shape(
465476
input_size = collections.deque(tensor.size())
466477
output_size = []
467478
env = CompileEnvironment.current()
468-
for k in index:
479+
for i, k in enumerate(index):
469480
if k is None:
470481
output_size.append(1)
482+
elif k is Ellipsis:
483+
# Ellipsis expands to consume all remaining dims except those after it
484+
ellipsis_dims = _calculate_ellipsis_dims(index, i, len(tensor.size()))
485+
for _ in range(ellipsis_dims):
486+
size = input_size.popleft()
487+
if size != 1:
488+
rdim = env.allocate_reduction_dimension(size)
489+
output_size.append(rdim.var)
490+
else:
491+
output_size.append(1)
471492
elif isinstance(k, int):
472493
input_size.popleft()
473494
elif isinstance(k, torch.SymInt):
@@ -517,6 +538,22 @@ def create(
517538
for n, k in enumerate(index):
518539
if k is None:
519540
output_idx += 1
541+
elif k is Ellipsis:
542+
# Ellipsis expands to handle remaining dimensions
543+
ellipsis_dims = _calculate_ellipsis_dims(index, n, fake_value.ndim)
544+
for _ in range(ellipsis_dims):
545+
expand = tile_strategy.expand_str(output_size, output_idx)
546+
size = fake_value.size(len(index_values))
547+
if size != 1:
548+
rdim = env.allocate_reduction_dimension(size)
549+
block_idx = rdim.block_id
550+
index_var = state.codegen.index_var(block_idx)
551+
index_values.append(f"({index_var}){expand}")
552+
if mask := state.codegen.mask_var(block_idx):
553+
mask_values.setdefault(f"({mask}){expand}")
554+
else:
555+
index_values.append(f"tl.zeros([1], {dtype}){expand}")
556+
output_idx += 1
520557
elif isinstance(k, int):
521558
index_values.append(repr(k))
522559
elif isinstance(k, torch.SymInt):
@@ -729,8 +766,17 @@ def is_supported(
729766
# TODO(jansel): support block_ptr with extra_mask
730767
return False
731768
input_sizes = collections.deque(fake_tensor.size())
732-
for k in index:
733-
input_size = 1 if k is None else input_sizes.popleft()
769+
for n, k in enumerate(index):
770+
if k is None:
771+
input_size = 1
772+
elif k is Ellipsis:
773+
# Skip appropriate number of dimensions for ellipsis
774+
ellipsis_dims = _calculate_ellipsis_dims(index, n, fake_tensor.ndim)
775+
for _ in range(ellipsis_dims):
776+
input_sizes.popleft()
777+
continue
778+
else:
779+
input_size = input_sizes.popleft()
734780
if isinstance(k, torch.SymInt):
735781
symbol = k._sympy_()
736782
origin = None
@@ -780,9 +826,22 @@ def create(
780826
fake_value,
781827
reshaped_size=SubscriptIndexing.compute_shape(fake_value, index),
782828
)
783-
for k in index:
829+
for n, k in enumerate(index):
784830
if k is None:
785831
pass # handled by reshaped_size
832+
elif k is Ellipsis:
833+
# Ellipsis expands to handle remaining dimensions
834+
ellipsis_dims = _calculate_ellipsis_dims(index, n, fake_value.ndim)
835+
env = CompileEnvironment.current()
836+
for _ in range(ellipsis_dims):
837+
size = fake_value.size(len(res.offsets))
838+
if size != 1:
839+
rdim = env.allocate_reduction_dimension(size)
840+
res.offsets.append(state.codegen.offset_var(rdim.block_id))
841+
res.block_shape.append(rdim.var)
842+
else:
843+
res.offsets.append("0")
844+
res.block_shape.append(1)
786845
elif isinstance(k, int):
787846
res.offsets.append(repr(k))
788847
res.block_shape.append(1)

helion/_compiler/type_propagation.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,28 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
433433
inputs_consumed += 1
434434
elif k.value is None:
435435
output_sizes.append(1)
436+
elif k.value is Ellipsis:
437+
# Ellipsis consumes all remaining dimensions except those after it
438+
# Count how many indices come after the ellipsis
439+
remaining_keys = sum(
440+
1
441+
for key in keys[keys.index(k) + 1 :]
442+
if not (isinstance(key, LiteralType) and key.value is None)
443+
)
444+
# Consume all dimensions except those needed for remaining keys
445+
ellipsis_dims = (
446+
self.fake_value.ndim - inputs_consumed - remaining_keys
447+
)
448+
for _ in range(ellipsis_dims):
449+
size = self.fake_value.size(inputs_consumed)
450+
inputs_consumed += 1
451+
if self.origin.is_device():
452+
output_sizes.append(size)
453+
elif size != 1:
454+
rdim = env.allocate_reduction_dimension(size)
455+
output_sizes.append(rdim.var)
456+
else:
457+
output_sizes.append(1)
436458
else:
437459
raise exc.InvalidIndexingType(k)
438460
elif isinstance(k, SymIntType):

test/test_indexing.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -759,9 +759,6 @@ def kernel(
759759
torch.testing.assert_close(src_result, expected_src)
760760
torch.testing.assert_close(dst_result, expected_dst)
761761

762-
@skipIfNormalMode(
763-
"RankMismatch: Cannot assign a tensor of rank 2 to a buffer of rank 3"
764-
)
765762
def test_ellipsis_indexing(self):
766763
"""Test both setter from scalar and getter for [..., i]"""
767764

0 commit comments

Comments (0)