Commit f812f6e

Add support for ellipsis (...) indexing in Helion
stack-info: PR: #437, branch: yf225/stack/53
1 parent 4718678 commit f812f6e

3 files changed: +81, -7 lines

helion/_compiler/indexing_strategy.py

Lines changed: 59 additions & 4 deletions
@@ -227,6 +227,9 @@ def valid_block_size(
         for i, k in enumerate(subscript):
             if k is None:
                 continue
+            if k is Ellipsis:
+                # Ellipsis is not supported in tensor descriptor mode
+                return False
             size, stride = size_stride.popleft()
             if isinstance(k, slice):
                 # Slices with steps are not supported in tensor descriptor mode
@@ -465,9 +468,20 @@ def compute_shape(
         input_size = collections.deque(tensor.size())
         output_size = []
         env = CompileEnvironment.current()
-        for k in index:
+        for i, k in enumerate(index):
             if k is None:
                 output_size.append(1)
+            elif k is Ellipsis:
+                # Ellipsis expands to consume all remaining dims except those after it
+                remaining_indices = len(index) - i - 1
+                ellipsis_dims = len(input_size) - remaining_indices
+                for _ in range(ellipsis_dims):
+                    size = input_size.popleft()
+                    if size != 1:
+                        rdim = env.allocate_reduction_dimension(size)
+                        output_size.append(rdim.var)
+                    else:
+                        output_size.append(1)
             elif isinstance(k, int):
                 input_size.popleft()
             elif isinstance(k, torch.SymInt):
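
As a standalone illustration of the arithmetic above (plain Python, not compiler code): the ellipsis covers every input dimension not claimed by the other indices, so it expands to that many full slices. The helper name below is made up for the example.

# Illustrative only: expand `...` in an index tuple into explicit slices so
# that every input dimension is accounted for.
def expand_ellipsis(index: tuple, ndim: int) -> tuple:
    i = index.index(Ellipsis)
    # `None` entries add new axes, so they do not consume input dimensions.
    consumed = sum(1 for k in index if k is not None and k is not Ellipsis)
    ellipsis_dims = ndim - consumed
    return index[:i] + (slice(None),) * ellipsis_dims + index[i + 1 :]

# x[..., 2] on a (2, 3, 4) tensor behaves like x[:, :, 2]:
assert expand_ellipsis((Ellipsis, 2), ndim=3) == (slice(None), slice(None), 2)
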
@@ -517,6 +531,23 @@ def create(
         for n, k in enumerate(index):
             if k is None:
                 output_idx += 1
+            elif k is Ellipsis:
+                # Ellipsis expands to handle remaining dimensions
+                remaining_indices = len(index) - n - 1
+                ellipsis_dims = fake_value.ndim - len(index_values) - remaining_indices
+                for dim_offset in range(ellipsis_dims):
+                    expand = tile_strategy.expand_str(output_size, output_idx)
+                    size = fake_value.size(len(index_values))
+                    if size != 1:
+                        rdim = env.allocate_reduction_dimension(size)
+                        block_idx = rdim.block_id
+                        index_var = state.codegen.index_var(block_idx)
+                        index_values.append(f"({index_var}){expand}")
+                        if mask := state.codegen.mask_var(block_idx):
+                            mask_values.setdefault(f"({mask}){expand}")
+                    else:
+                        index_values.append(f"tl.zeros([1], {dtype}){expand}")
+                    output_idx += 1
             elif isinstance(k, int):
                 index_values.append(repr(k))
             elif isinstance(k, torch.SymInt):
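
The `expand_str` call above produces a broadcasting suffix so that each 1-D index variable lines up with the final output shape. A rough, hypothetical sketch of what such a suffix builder might look like; the real helper lives on the tile strategy and its exact output may differ.

# Hypothetical stand-in for tile_strategy.expand_str(...): keep ":" at the
# target position and insert "None" for every other output dimension.
def expand_str(output_size: list, dim: int) -> str:
    parts = [":" if i == dim else "None" for i in range(len(output_size))]
    return f"[{', '.join(parts)}]"

# For a 3-D output, the index for dimension 1 would be written in the
# generated source as something like "(indices_1)[None, :, None]".
assert expand_str([8, 16, 32], 1) == "[None, :, None]"
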
@@ -729,8 +760,18 @@ def is_supported(
             # TODO(jansel): support block_ptr with extra_mask
             return False
         input_sizes = collections.deque(fake_tensor.size())
-        for k in index:
-            input_size = 1 if k is None else input_sizes.popleft()
+        for n, k in enumerate(index):
+            if k is None:
+                input_size = 1
+            elif k is Ellipsis:
+                # Skip appropriate number of dimensions for ellipsis
+                remaining_indices = len(index) - n - 1
+                ellipsis_dims = len(input_sizes) - remaining_indices
+                for _ in range(ellipsis_dims):
+                    input_sizes.popleft()
+                continue
+            else:
+                input_size = input_sizes.popleft()
             if isinstance(k, torch.SymInt):
                 symbol = k._sympy_()
                 origin = None
@@ -780,9 +821,23 @@ def create(
             fake_value,
             reshaped_size=SubscriptIndexing.compute_shape(fake_value, index),
         )
-        for k in index:
+        for n, k in enumerate(index):
             if k is None:
                 pass  # handled by reshaped_size
+            elif k is Ellipsis:
+                # Ellipsis expands to handle remaining dimensions
+                remaining_indices = len(index) - n - 1
+                ellipsis_dims = fake_value.ndim - len(res.offsets) - remaining_indices
+                for _ in range(ellipsis_dims):
+                    size = fake_value.size(len(res.offsets))
+                    if size != 1:
+                        env = CompileEnvironment.current()
+                        rdim = env.allocate_reduction_dimension(size)
+                        res.offsets.append(state.codegen.offset_var(rdim.block_id))
+                        res.block_shape.append(rdim.var)
+                    else:
+                        res.offsets.append("0")
+                        res.block_shape.append(1)
             elif isinstance(k, int):
                 res.offsets.append(repr(k))
                 res.block_shape.append(1)
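
Conceptually, the block-pointer path above turns every dimension covered by the ellipsis into an (offset, block_shape) pair, while integer indices contribute a fixed offset with a size-1 block. A rough standalone sketch of that bookkeeping in plain Python; the offset_* / BLOCK_* names are placeholders, not the real codegen variables.

# Illustrative only: mimic how ellipsis dims contribute offsets/block shapes.
def block_ptr_dims(shape: tuple, index: tuple) -> tuple:
    offsets, block_shape = [], []
    for n, k in enumerate(index):
        if k is Ellipsis:
            remaining = len(index) - n - 1
            for d in range(len(offsets), len(shape) - remaining):
                if shape[d] != 1:
                    offsets.append(f"offset_{d}")  # stand-in for offset_var
                    block_shape.append(f"BLOCK_{d}")  # stand-in for rdim.var
                else:
                    offsets.append("0")
                    block_shape.append(1)
        else:  # plain integer index: fixed offset, size-1 block
            offsets.append(repr(k))
            block_shape.append(1)
    return offsets, block_shape

# x[..., i] on a (B, M, N) tensor: two blocked dims plus one size-1 slot.
assert block_ptr_dims((4, 8, 16), (Ellipsis, 2)) == (
    ["offset_0", "offset_1", "2"],
    ["BLOCK_0", "BLOCK_1", 1],
)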

helion/_compiler/type_propagation.py

Lines changed: 22 additions & 0 deletions
@@ -433,6 +433,28 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
                     inputs_consumed += 1
                 elif k.value is None:
                     output_sizes.append(1)
+                elif k.value is Ellipsis:
+                    # Ellipsis consumes all remaining dimensions except those after it
+                    # Count how many indices come after the ellipsis
+                    remaining_keys = sum(
+                        1
+                        for key in keys[keys.index(k) + 1 :]
+                        if not (isinstance(key, LiteralType) and key.value is None)
+                    )
+                    # Consume all dimensions except those needed for remaining keys
+                    ellipsis_dims = (
+                        self.fake_value.ndim - inputs_consumed - remaining_keys
+                    )
+                    for _ in range(ellipsis_dims):
+                        size = self.fake_value.size(inputs_consumed)
+                        inputs_consumed += 1
+                        if self.origin.is_device():
+                            output_sizes.append(size)
+                        elif size != 1:
+                            rdim = env.allocate_reduction_dimension(size)
+                            output_sizes.append(rdim.var)
+                        else:
+                            output_sizes.append(1)
                 else:
                     raise exc.InvalidIndexingType(k)
             elif isinstance(k, SymIntType):
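
Note that, unlike the positional counting in `compute_shape`, this pass skips `None` keys when counting what follows the ellipsis, since `None` adds an output axis without consuming an input dimension. A small standalone illustration of that counting (plain Python; the helper name is made up):

# Illustrative only: count the entries after `...` that actually consume an
# input dimension, then derive how many dimensions the ellipsis covers.
def dims_consumed_by_ellipsis(keys: list, ndim: int, consumed_before: int) -> int:
    after = keys[keys.index(Ellipsis) + 1 :]
    remaining_keys = sum(1 for k in after if k is not None)
    return ndim - consumed_before - remaining_keys

# For x[..., None, 2] on a 3-D tensor: `None` adds an axis, `2` consumes one
# dim, so the ellipsis covers the first two dimensions.
assert dims_consumed_by_ellipsis([Ellipsis, None, 2], ndim=3, consumed_before=0) == 2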

test/test_indexing.py

Lines changed: 0 additions & 3 deletions
@@ -759,9 +759,6 @@ def kernel(
         torch.testing.assert_close(src_result, expected_src)
         torch.testing.assert_close(dst_result, expected_dst)
 
-    @skipIfNormalMode(
-        "RankMismatch: Cannot assign a tensor of rank 2 to a buffer of rank 3"
-    )
     def test_ellipsis_indexing(self):
         """Test both setter from scalar and getter for [..., i]"""
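
For orientation, a rough sketch of the kind of kernel such a test exercises, assuming Helion's usual `helion.kernel` / `hl.tile` surface; the body, names, and shapes are illustrative guesses, not code taken from the test.

import torch
import helion
import helion.language as hl


# Illustrative guess: for a 3-D input, read x[tile, ..., i] on the device
# without spelling out the middle dimension explicitly.
@helion.kernel
def take_last(x: torch.Tensor, i: int) -> torch.Tensor:
    out = torch.empty(x.shape[:-1], device=x.device, dtype=x.dtype)
    for tile in hl.tile(x.size(0)):
        out[tile, :] = x[tile, ..., i]
    return out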
