@@ -7331,58 +7331,59 @@ def aten_repeat_interleave_self_tensor(
73317331 # Convert repeats to int64 for ONNX compatibility
73327332 repeats_int64 = op .Cast (repeats , to = INT64 .dtype )
73337333
7334- # Use an approach similar to self_int but adapted for variable repeats
7335- # The key optimization: avoid creating large intermediate index tensors
7336-
7337- # Get cumulative sum to determine output positions
7334+ # Get cumulative sum of repeats to find the boundaries
73387335 cumsum = op .CumSum (repeats_int64 , axis = 0 )
73397336 total_size = op .Gather (cumsum , op .Constant (value_ints = [- 1 ]), axis = 0 )
73407337
7341- # Create output indices
7338+ # Create output tensor indices
73427339 output_range = op .Range (
73437340 op .Constant (value_ints = [0 ]), total_size , op .Constant (value_ints = [1 ])
73447341 )
73457342
7346- # More efficient searchsorted: find input index for each output position
7347- # Broadcast to find positions where output_idx < cumsum_val
7348- cumsum_expanded = op .Unsqueeze (cumsum , [0 ]) # [1, n_elements]
7349- output_expanded = op .Unsqueeze (output_range , [1 ]) # [total_size, 1]
7343+ # Find which original index each output position corresponds to
7344+ cumsum_expanded = op .Unsqueeze (cumsum , [0 ]) # Shape: [1, len(repeats)]
7345+ output_range_expanded = op .Unsqueeze (output_range , [1 ]) # Shape: [total_size, 1]
7346+
7347+ # Find positions where output_range < cumsum
7348+ mask = op .Less (
7349+ output_range_expanded , cumsum_expanded
7350+ ) # Shape: [total_size, len(repeats)]
73507351
7351- # Find first position where output_idx < cumsum_val
7352- mask = op .Less (output_expanded , cumsum_expanded ) # [total_size, n_elements]
7353- input_indices = op .ArgMax (op .Cast (mask , to = INT64 .dtype ), axis = 1 , keepdims = False )
7352+ # For each row, find the first True position
7353+ indices = op .ArgMax (op .Cast (mask , to = INT64 .dtype ), axis = 1 , keepdims = False )
73547354
7355- # Gather the actual values
7356- result = op .Gather (self_flat , input_indices , axis = 0 )
7355+ # Gather elements from the flattened tensor
7356+ result = op .Gather (self_flat , indices , axis = 0 )
73577357 return result
73587358
73597359 else :
7360- # Repeat along specific dimension - use approach similar to optimized self_int
7360+ # Repeat along specific dimension
73617361 # Convert repeats to int64 for ONNX compatibility
73627362 repeats_int64 = op .Cast (repeats , to = INT64 .dtype )
73637363
7364- # Use a more efficient approach similar to self_int optimization
7365- # The challenge is that we have variable repeat counts per slice
7366-
7367- # Get cumulative sum to find boundaries (this part is necessary for variable repeats)
7364+ # Get cumulative sum of repeats to find the boundaries
73687365 cumsum = op .CumSum (repeats_int64 , axis = 0 )
73697366 total_size = op .Gather (cumsum , op .Constant (value_ints = [- 1 ]), axis = 0 )
73707367
7371- # Create output indices for the dimension
7368+ # Create output tensor indices for the specified dimension
73727369 output_range = op .Range (
73737370 op .Constant (value_ints = [0 ]), total_size , op .Constant (value_ints = [1 ])
73747371 )
73757372
7376- # Efficient mapping from output positions to input indices
7377- cumsum_expanded = op .Unsqueeze (cumsum , [0 ]) # [1, n_slices]
7378- output_expanded = op .Unsqueeze (output_range , [1 ]) # [total_size, 1]
7373+ # Find which original index each output position corresponds to
7374+ cumsum_expanded = op .Unsqueeze (cumsum , [0 ]) # Shape: [1, len(repeats)]
7375+ output_range_expanded = op .Unsqueeze (output_range , [1 ]) # Shape: [total_size, 1]
7376+
7377+ # Find positions where output_range < cumsum
7378+ mask = op .Less (
7379+ output_range_expanded , cumsum_expanded
7380+ ) # Shape: [total_size, len(repeats)]
73797381
7380- # Find input slice index for each output position
7381- mask = op .Less (output_expanded , cumsum_expanded ) # [total_size, n_slices]
7382- input_indices = op .ArgMax (op .Cast (mask , to = INT64 .dtype ), axis = 1 , keepdims = False )
7382+ # For each row, find the first True position
7383+ indices = op .ArgMax (op .Cast (mask , to = INT64 .dtype ), axis = 1 , keepdims = False )
73837384
7384- # Gather slices along the specified dimension
7385- result = op .Gather (self , input_indices , axis = dim )
7385+ # Gather elements along the specified dimension
7386+ result = op .Gather (self , indices , axis = dim )
73867387 return result
73877388
73887389
0 commit comments