diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py
index a98b0a21..636e76bc 100644
--- a/helion/_compiler/indexing_strategy.py
+++ b/helion/_compiler/indexing_strategy.py
@@ -227,6 +227,9 @@ def valid_block_size(
         for i, k in enumerate(subscript):
             if k is None:
                 continue
+            if k is Ellipsis:
+                # Ellipsis is not supported in tensor descriptor mode
+                return False
             size, stride = size_stride.popleft()
             if isinstance(k, slice):
                 # Slices with steps are not supported in tensor descriptor mode
@@ -447,6 +450,14 @@ def codegen_store(
     )
 
 
+def _calculate_ellipsis_dims(
+    index: list[object], current_index: int, total_dims: int
+) -> int:
+    """Calculate how many dimensions an ellipsis should expand to."""
+    remaining_indices = len(index) - current_index - 1
+    return total_dims - current_index - remaining_indices
+
+
 class SubscriptIndexing(NamedTuple):
     index_expr: ast.AST
     mask_expr: ast.AST
@@ -465,9 +476,18 @@ def compute_shape(
         input_size = collections.deque(tensor.size())
         output_size = []
         env = CompileEnvironment.current()
-        for k in index:
+        for i, k in enumerate(index):
             if k is None:
                 output_size.append(1)
+            elif k is Ellipsis:
+                ellipsis_dims = _calculate_ellipsis_dims(index, i, len(tensor.size()))
+                for _ in range(ellipsis_dims):
+                    size = input_size.popleft()
+                    if size != 1:
+                        rdim = env.allocate_reduction_dimension(size)
+                        output_size.append(rdim.var)
+                    else:
+                        output_size.append(1)
             elif isinstance(k, int):
                 input_size.popleft()
             elif isinstance(k, torch.SymInt):
@@ -517,6 +537,21 @@ def create(
         for n, k in enumerate(index):
             if k is None:
                 output_idx += 1
+            elif k is Ellipsis:
+                ellipsis_dims = _calculate_ellipsis_dims(index, n, fake_value.ndim)
+                for _ in range(ellipsis_dims):
+                    expand = tile_strategy.expand_str(output_size, output_idx)
+                    size = fake_value.size(len(index_values))
+                    if size != 1:
+                        rdim = env.allocate_reduction_dimension(size)
+                        block_idx = rdim.block_id
+                        index_var = state.codegen.index_var(block_idx)
+                        index_values.append(f"({index_var}){expand}")
+                        if mask := state.codegen.mask_var(block_idx):
+                            mask_values.setdefault(f"({mask}){expand}")
+                    else:
+                        index_values.append(f"tl.zeros([1], {dtype}){expand}")
+                    output_idx += 1
             elif isinstance(k, int):
                 index_values.append(repr(k))
             elif isinstance(k, torch.SymInt):
@@ -729,8 +764,16 @@ def is_supported(
             # TODO(jansel): support block_ptr with extra_mask
             return False
         input_sizes = collections.deque(fake_tensor.size())
-        for k in index:
-            input_size = 1 if k is None else input_sizes.popleft()
+        for n, k in enumerate(index):
+            if k is None:
+                input_size = 1
+            elif k is Ellipsis:
+                ellipsis_dims = _calculate_ellipsis_dims(index, n, fake_tensor.ndim)
+                for _ in range(ellipsis_dims):
+                    input_sizes.popleft()
+                continue
+            else:
+                input_size = input_sizes.popleft()
             if isinstance(k, torch.SymInt):
                 symbol = k._sympy_()
                 origin = None
@@ -780,9 +823,21 @@ def create(
             fake_value,
             reshaped_size=SubscriptIndexing.compute_shape(fake_value, index),
         )
-        for k in index:
+        for n, k in enumerate(index):
             if k is None:
                 pass  # handled by reshaped_size
+            elif k is Ellipsis:
+                ellipsis_dims = _calculate_ellipsis_dims(index, n, fake_value.ndim)
+                env = CompileEnvironment.current()
+                for _ in range(ellipsis_dims):
+                    size = fake_value.size(len(res.offsets))
+                    if size != 1:
+                        rdim = env.allocate_reduction_dimension(size)
+                        res.offsets.append(state.codegen.offset_var(rdim.block_id))
+                        res.block_shape.append(rdim.var)
+                    else:
+                        res.offsets.append("0")
+                        res.block_shape.append(1)
             elif isinstance(k, int):
                 res.offsets.append(repr(k))
                 res.block_shape.append(1)
diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py
index 615e4e80..97a042b2 100644
--- a/helion/_compiler/type_propagation.py
+++ b/helion/_compiler/type_propagation.py
@@ -433,6 +433,26 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
                     inputs_consumed += 1
                 elif k.value is None:
                     output_sizes.append(1)
+                elif k.value is Ellipsis:
+                    # Count indices after ellipsis (excluding None)
+                    remaining_keys = sum(
+                        1
+                        for key in keys[keys.index(k) + 1 :]
+                        if not (isinstance(key, LiteralType) and key.value is None)
+                    )
+                    ellipsis_dims = (
+                        self.fake_value.ndim - inputs_consumed - remaining_keys
+                    )
+                    for _ in range(ellipsis_dims):
+                        size = self.fake_value.size(inputs_consumed)
+                        inputs_consumed += 1
+                        if self.origin.is_device():
+                            output_sizes.append(size)
+                        elif size != 1:
+                            rdim = env.allocate_reduction_dimension(size)
+                            output_sizes.append(rdim.var)
+                        else:
+                            output_sizes.append(1)
                 else:
                     raise exc.InvalidIndexingType(k)
             elif isinstance(k, SymIntType):
diff --git a/test/test_indexing.py b/test/test_indexing.py
index c3b8b150..5cfc8a6d 100644
--- a/test/test_indexing.py
+++ b/test/test_indexing.py
@@ -759,9 +759,6 @@ def kernel(
         torch.testing.assert_close(src_result, expected_src)
         torch.testing.assert_close(dst_result, expected_dst)
 
-    @skipIfNormalMode(
-        "RankMismatch: Cannot assign a tensor of rank 2 to a buffer of rank 3"
-    )
     def test_ellipsis_indexing(self):
        """Test both setter from scalar and getter for [..., i]"""
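
For reference, below is a minimal standalone sketch of the dimension arithmetic that the new _calculate_ellipsis_dims helper performs (copied from the patch above); the string entries in the example index lists are hypothetical placeholders for real index objects such as tile or slice values. It is runnable on its own, with no imports required.

# Sketch of the arithmetic added by _calculate_ellipsis_dims in
# helion/_compiler/indexing_strategy.py; string index entries below
# are hypothetical stand-ins for real index objects.
def _calculate_ellipsis_dims(
    index: list[object], current_index: int, total_dims: int
) -> int:
    """Calculate how many dimensions an ellipsis should expand to."""
    # Indices appearing after the ellipsis each consume one trailing dim.
    remaining_indices = len(index) - current_index - 1
    # The ellipsis absorbs whatever dims are not consumed before or after it.
    return total_dims - current_index - remaining_indices

# x[..., i] on a rank-4 tensor: the ellipsis covers the 3 leading dims.
assert _calculate_ellipsis_dims([Ellipsis, "i"], 0, 4) == 3

# x[i, ..., j] on a rank-4 tensor: 2 dims remain for the ellipsis.
assert _calculate_ellipsis_dims(["i", Ellipsis, "j"], 1, 4) == 2

Note that the type_propagation.py path computes the same quantity inline rather than calling this helper, and per its comment it excludes None (newaxis) entries when counting the keys after the ellipsis, since None consumes no input dimension.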