@@ -549,7 +549,21 @@ def create(
549
549
index_values .append (f"tl.zeros([1], { dtype } ){ expand } " )
550
550
output_idx += 1
551
551
elif isinstance (k , int ):
552
- index_values .append (repr (k ))
552
+ # Normalize negative indices
553
+ if k < 0 :
554
+ dim_idx = len (index_values )
555
+ dim_size = fake_value .size (dim_idx )
556
+ # Handle both concrete and symbolic dimension sizes
557
+ if isinstance (dim_size , int ):
558
+ normalized_k = k + dim_size
559
+ index_values .append (repr (normalized_k ))
560
+ else :
561
+ # For symbolic dimensions, we need to generate the proper expression
562
+ # The state.codegen is a GenerateAST instance which has device_function
563
+ sympy_expr = dim_size ._sympy_ () + k
564
+ index_values .append (f"({ state .codegen .device_function .user_sympy_expr (sympy_expr )} )" )
565
+ else :
566
+ index_values .append (repr (k ))
553
567
elif isinstance (k , torch .SymInt ):
554
568
symbol = k ._sympy_ ()
555
569
origin = None
@@ -839,7 +853,21 @@ def create(
839
853
res .offsets .append ("0" )
840
854
res .block_shape .append (1 )
841
855
elif isinstance (k , int ):
842
- res .offsets .append (repr (k ))
856
+ # Normalize negative indices
857
+ if k < 0 :
858
+ dim_idx = len (res .offsets )
859
+ dim_size = fake_value .size (dim_idx )
860
+ # Handle both concrete and symbolic dimension sizes
861
+ if isinstance (dim_size , int ):
862
+ normalized_k = k + dim_size
863
+ res .offsets .append (repr (normalized_k ))
864
+ else :
865
+ # For symbolic dimensions, we need to generate the proper expression
866
+ # The state.codegen is a GenerateAST instance which has device_function
867
+ sympy_expr = dim_size ._sympy_ () + k
868
+ res .offsets .append (f"({ state .codegen .device_function .user_sympy_expr (sympy_expr )} )" )
869
+ else :
870
+ res .offsets .append (repr (k ))
843
871
res .block_shape .append (1 )
844
872
elif isinstance (k , torch .SymInt ):
845
873
symbol = k ._sympy_ ()
0 commit comments