@@ -227,6 +227,9 @@ def valid_block_size(
227
227
for i , k in enumerate (subscript ):
228
228
if k is None :
229
229
continue
230
+ if k is Ellipsis :
231
+ # Ellipsis is not supported in tensor descriptor mode
232
+ return False
230
233
size , stride = size_stride .popleft ()
231
234
if isinstance (k , slice ):
232
235
# Slices with steps are not supported in tensor descriptor mode
@@ -465,9 +468,20 @@ def compute_shape(
465
468
input_size = collections .deque (tensor .size ())
466
469
output_size = []
467
470
env = CompileEnvironment .current ()
468
- for k in index :
471
+ for i , k in enumerate ( index ) :
469
472
if k is None :
470
473
output_size .append (1 )
474
+ elif k is Ellipsis :
475
+ # Ellipsis expands to consume all remaining dims except those after it
476
+ remaining_indices = len (index ) - i - 1
477
+ ellipsis_dims = len (input_size ) - remaining_indices
478
+ for _ in range (ellipsis_dims ):
479
+ size = input_size .popleft ()
480
+ if size != 1 :
481
+ rdim = env .allocate_reduction_dimension (size )
482
+ output_size .append (rdim .var )
483
+ else :
484
+ output_size .append (1 )
471
485
elif isinstance (k , int ):
472
486
input_size .popleft ()
473
487
elif isinstance (k , torch .SymInt ):
@@ -517,6 +531,23 @@ def create(
517
531
for n , k in enumerate (index ):
518
532
if k is None :
519
533
output_idx += 1
534
+ elif k is Ellipsis :
535
+ # Ellipsis expands to handle remaining dimensions
536
+ remaining_indices = len (index ) - n - 1
537
+ ellipsis_dims = fake_value .ndim - len (index_values ) - remaining_indices
538
+ for dim_offset in range (ellipsis_dims ):
539
+ expand = tile_strategy .expand_str (output_size , output_idx )
540
+ size = fake_value .size (len (index_values ))
541
+ if size != 1 :
542
+ rdim = env .allocate_reduction_dimension (size )
543
+ block_idx = rdim .block_id
544
+ index_var = state .codegen .index_var (block_idx )
545
+ index_values .append (f"({ index_var } ){ expand } " )
546
+ if mask := state .codegen .mask_var (block_idx ):
547
+ mask_values .setdefault (f"({ mask } ){ expand } " )
548
+ else :
549
+ index_values .append (f"tl.zeros([1], { dtype } ){ expand } " )
550
+ output_idx += 1
520
551
elif isinstance (k , int ):
521
552
index_values .append (repr (k ))
522
553
elif isinstance (k , torch .SymInt ):
@@ -729,8 +760,18 @@ def is_supported(
729
760
# TODO(jansel): support block_ptr with extra_mask
730
761
return False
731
762
input_sizes = collections .deque (fake_tensor .size ())
732
- for k in index :
733
- input_size = 1 if k is None else input_sizes .popleft ()
763
+ for n , k in enumerate (index ):
764
+ if k is None :
765
+ input_size = 1
766
+ elif k is Ellipsis :
767
+ # Skip appropriate number of dimensions for ellipsis
768
+ remaining_indices = len (index ) - n - 1
769
+ ellipsis_dims = len (input_sizes ) - remaining_indices
770
+ for _ in range (ellipsis_dims ):
771
+ input_sizes .popleft ()
772
+ continue
773
+ else :
774
+ input_size = input_sizes .popleft ()
734
775
if isinstance (k , torch .SymInt ):
735
776
symbol = k ._sympy_ ()
736
777
origin = None
@@ -780,9 +821,23 @@ def create(
780
821
fake_value ,
781
822
reshaped_size = SubscriptIndexing .compute_shape (fake_value , index ),
782
823
)
783
- for k in index :
824
+ for n , k in enumerate ( index ) :
784
825
if k is None :
785
826
pass # handled by reshaped_size
827
+ elif k is Ellipsis :
828
+ # Ellipsis expands to handle remaining dimensions
829
+ remaining_indices = len (index ) - n - 1
830
+ ellipsis_dims = fake_value .ndim - len (res .offsets ) - remaining_indices
831
+ for _ in range (ellipsis_dims ):
832
+ size = fake_value .size (len (res .offsets ))
833
+ if size != 1 :
834
+ env = CompileEnvironment .current ()
835
+ rdim = env .allocate_reduction_dimension (size )
836
+ res .offsets .append (state .codegen .offset_var (rdim .block_id ))
837
+ res .block_shape .append (rdim .var )
838
+ else :
839
+ res .offsets .append ("0" )
840
+ res .block_shape .append (1 )
786
841
elif isinstance (k , int ):
787
842
res .offsets .append (repr (k ))
788
843
res .block_shape .append (1 )
0 commit comments