@@ -227,6 +227,9 @@ def valid_block_size(
227
227
for i , k in enumerate (subscript ):
228
228
if k is None :
229
229
continue
230
+ if k is Ellipsis :
231
+ # Ellipsis is not supported in tensor descriptor mode
232
+ return False
230
233
size , stride = size_stride .popleft ()
231
234
if isinstance (k , slice ):
232
235
# Slices with steps are not supported in tensor descriptor mode
@@ -447,6 +450,14 @@ def codegen_store(
447
450
)
448
451
449
452
453
+ def _calculate_ellipsis_dims (
454
+ index : list [object ], current_index : int , total_dims : int
455
+ ) -> int :
456
+ """Calculate how many dimensions an ellipsis should expand to."""
457
+ remaining_indices = len (index ) - current_index - 1
458
+ return total_dims - current_index - remaining_indices
459
+
460
+
450
461
class SubscriptIndexing (NamedTuple ):
451
462
index_expr : ast .AST
452
463
mask_expr : ast .AST
@@ -465,9 +476,19 @@ def compute_shape(
465
476
input_size = collections .deque (tensor .size ())
466
477
output_size = []
467
478
env = CompileEnvironment .current ()
468
- for k in index :
479
+ for i , k in enumerate ( index ) :
469
480
if k is None :
470
481
output_size .append (1 )
482
+ elif k is Ellipsis :
483
+ # Ellipsis expands to consume all remaining dims except those after it
484
+ ellipsis_dims = _calculate_ellipsis_dims (index , i , len (tensor .size ()))
485
+ for _ in range (ellipsis_dims ):
486
+ size = input_size .popleft ()
487
+ if size != 1 :
488
+ rdim = env .allocate_reduction_dimension (size )
489
+ output_size .append (rdim .var )
490
+ else :
491
+ output_size .append (1 )
471
492
elif isinstance (k , int ):
472
493
input_size .popleft ()
473
494
elif isinstance (k , torch .SymInt ):
@@ -517,6 +538,22 @@ def create(
517
538
for n , k in enumerate (index ):
518
539
if k is None :
519
540
output_idx += 1
541
+ elif k is Ellipsis :
542
+ # Ellipsis expands to handle remaining dimensions
543
+ ellipsis_dims = _calculate_ellipsis_dims (index , n , fake_value .ndim )
544
+ for _ in range (ellipsis_dims ):
545
+ expand = tile_strategy .expand_str (output_size , output_idx )
546
+ size = fake_value .size (len (index_values ))
547
+ if size != 1 :
548
+ rdim = env .allocate_reduction_dimension (size )
549
+ block_idx = rdim .block_id
550
+ index_var = state .codegen .index_var (block_idx )
551
+ index_values .append (f"({ index_var } ){ expand } " )
552
+ if mask := state .codegen .mask_var (block_idx ):
553
+ mask_values .setdefault (f"({ mask } ){ expand } " )
554
+ else :
555
+ index_values .append (f"tl.zeros([1], { dtype } ){ expand } " )
556
+ output_idx += 1
520
557
elif isinstance (k , int ):
521
558
index_values .append (repr (k ))
522
559
elif isinstance (k , torch .SymInt ):
@@ -729,8 +766,17 @@ def is_supported(
729
766
# TODO(jansel): support block_ptr with extra_mask
730
767
return False
731
768
input_sizes = collections .deque (fake_tensor .size ())
732
- for k in index :
733
- input_size = 1 if k is None else input_sizes .popleft ()
769
+ for n , k in enumerate (index ):
770
+ if k is None :
771
+ input_size = 1
772
+ elif k is Ellipsis :
773
+ # Skip appropriate number of dimensions for ellipsis
774
+ ellipsis_dims = _calculate_ellipsis_dims (index , n , fake_tensor .ndim )
775
+ for _ in range (ellipsis_dims ):
776
+ input_sizes .popleft ()
777
+ continue
778
+ else :
779
+ input_size = input_sizes .popleft ()
734
780
if isinstance (k , torch .SymInt ):
735
781
symbol = k ._sympy_ ()
736
782
origin = None
@@ -780,9 +826,22 @@ def create(
780
826
fake_value ,
781
827
reshaped_size = SubscriptIndexing .compute_shape (fake_value , index ),
782
828
)
783
- for k in index :
829
+ for n , k in enumerate ( index ) :
784
830
if k is None :
785
831
pass # handled by reshaped_size
832
+ elif k is Ellipsis :
833
+ # Ellipsis expands to handle remaining dimensions
834
+ ellipsis_dims = _calculate_ellipsis_dims (index , n , fake_value .ndim )
835
+ env = CompileEnvironment .current ()
836
+ for _ in range (ellipsis_dims ):
837
+ size = fake_value .size (len (res .offsets ))
838
+ if size != 1 :
839
+ rdim = env .allocate_reduction_dimension (size )
840
+ res .offsets .append (state .codegen .offset_var (rdim .block_id ))
841
+ res .block_shape .append (rdim .var )
842
+ else :
843
+ res .offsets .append ("0" )
844
+ res .block_shape .append (1 )
786
845
elif isinstance (k , int ):
787
846
res .offsets .append (repr (k ))
788
847
res .block_shape .append (1 )
0 commit comments