@@ -7331,59 +7331,58 @@ def aten_repeat_interleave_self_tensor(
73317331 # Convert repeats to int64 for ONNX compatibility
73327332 repeats_int64 = op .Cast (repeats , to = INT64 .dtype )
73337333
7334- # Get cumulative sum of repeats to find the boundaries
7334+ # Use an approach similar to self_int but adapted for variable repeats
7335+ # The key optimization: avoid creating large intermediate index tensors
7336+
7337+ # Get cumulative sum to determine output positions
73357338 cumsum = op .CumSum (repeats_int64 , axis = 0 )
73367339 total_size = op .Gather (cumsum , op .Constant (value_ints = [- 1 ]), axis = 0 )
73377340
7338- # Create output tensor indices
7341+ # Create output indices
73397342 output_range = op .Range (
73407343 op .Constant (value_ints = [0 ]), total_size , op .Constant (value_ints = [1 ])
73417344 )
73427345
7343- # Find which original index each output position corresponds to
7344- cumsum_expanded = op .Unsqueeze (cumsum , [0 ]) # Shape: [1, len(repeats)]
7345- output_range_expanded = op .Unsqueeze (output_range , [1 ]) # Shape: [total_size, 1]
7346-
7347- # Find positions where output_range < cumsum
7348- mask = op .Less (
7349- output_range_expanded , cumsum_expanded
7350- ) # Shape: [total_size, len(repeats)]
7346+ # More efficient searchsorted: find input index for each output position
7347+ # Broadcast to find positions where output_idx < cumsum_val
7348+ cumsum_expanded = op .Unsqueeze (cumsum , [0 ]) # [1, n_elements]
7349+ output_expanded = op .Unsqueeze (output_range , [1 ]) # [total_size, 1]
73517350
7352- # For each row, find the first True position
7353- indices = op .ArgMax (op .Cast (mask , to = INT64 .dtype ), axis = 1 , keepdims = False )
7351+ # Find first position where output_idx < cumsum_val
7352+ mask = op .Less (output_expanded , cumsum_expanded ) # [total_size, n_elements]
7353+ input_indices = op .ArgMax (op .Cast (mask , to = INT64 .dtype ), axis = 1 , keepdims = False )
73547354
7355- # Gather elements from the flattened tensor
7356- result = op .Gather (self_flat , indices , axis = 0 )
7355+ # Gather the actual values
7356+ result = op .Gather (self_flat , input_indices , axis = 0 )
73577357 return result
73587358
73597359 else :
7360- # Repeat along specific dimension
7360+ # Repeat along specific dimension - use approach similar to optimized self_int
73617361 # Convert repeats to int64 for ONNX compatibility
73627362 repeats_int64 = op .Cast (repeats , to = INT64 .dtype )
73637363
7364- # Get cumulative sum of repeats to find the boundaries
7364+ # Use a more efficient approach similar to self_int optimization
7365+ # The challenge is that we have variable repeat counts per slice
7366+
7367+ # Get cumulative sum to find boundaries (this part is necessary for variable repeats)
73657368 cumsum = op .CumSum (repeats_int64 , axis = 0 )
73667369 total_size = op .Gather (cumsum , op .Constant (value_ints = [- 1 ]), axis = 0 )
73677370
7368- # Create output tensor indices for the specified dimension
7371+ # Create output indices for the dimension
73697372 output_range = op .Range (
73707373 op .Constant (value_ints = [0 ]), total_size , op .Constant (value_ints = [1 ])
73717374 )
73727375
7373- # Find which original index each output position corresponds to
7374- cumsum_expanded = op .Unsqueeze (cumsum , [0 ]) # Shape: [1, len(repeats) ]
7375- output_range_expanded = op .Unsqueeze (output_range , [1 ]) # Shape: [total_size, 1]
7376+ # Efficient mapping from output positions to input indices
7377+ cumsum_expanded = op .Unsqueeze (cumsum , [0 ]) # [1, n_slices ]
7378+ output_expanded = op .Unsqueeze (output_range , [1 ]) # [total_size, 1]
73767379
7377- # Find positions where output_range < cumsum
7378- mask = op .Less (
7379- output_range_expanded , cumsum_expanded
7380- ) # Shape: [total_size, len(repeats)]
7380+ # Find input slice index for each output position
7381+ mask = op .Less (output_expanded , cumsum_expanded ) # [total_size, n_slices]
7382+ input_indices = op .ArgMax (op .Cast (mask , to = INT64 .dtype ), axis = 1 , keepdims = False )
73817383
7382- # For each row, find the first True position
7383- indices = op .ArgMax (op .Cast (mask , to = INT64 .dtype ), axis = 1 , keepdims = False )
7384-
7385- # Gather elements along the specified dimension
7386- result = op .Gather (self , indices , axis = dim )
7384+ # Gather slices along the specified dimension
7385+ result = op .Gather (self , input_indices , axis = dim )
73877386 return result
73887387
73897388
@@ -7399,7 +7398,7 @@ def aten_repeat_interleave_self_int(
73997398 if dim is None :
74007399 # Flatten the tensor first, then repeat each element 'repeats' times
74017400 self_flat = op .Reshape (self , [- 1 ])
7402-
7401+
74037402 # Add a new dimension and tile to repeat each element
74047403 self_expanded = op .Unsqueeze (self_flat , [1 ]) # Shape: [num_elements, 1]
74057404 repeat_pattern = op .Constant (value_ints = [1 , repeats ])
@@ -7410,40 +7409,41 @@ def aten_repeat_interleave_self_int(
74107409 else :
74117410 # Repeat along specific dimension
74127411 # Apply Tile directly to the tensor instead of creating indices (more efficient)
7413-
7412+
74147413 # Expand tensor by adding dimension after target dim
74157414 self_expanded = op .Unsqueeze (self , [dim + 1 ])
7416-
7415+
74177416 # Get original shape to build tile pattern dynamically
74187417 original_shape = op .Shape (self )
74197418 num_dims = op .Size (original_shape )
7420-
7419+
74217420 # Build tile pattern: all 1s except position dim+1 which is 'repeats'
74227421 # Use ConstantOfShape to create array of 1s, then update specific position
74237422 ones_pattern = op .ConstantOfShape (
74247423 op .Add (num_dims , op .Constant (value_ints = [1 ])), # +1 for the new dimension
7425- op .Constant (value_ints = [1 ])
7424+ op .Constant (value_ints = [1 ]),
74267425 )
7427-
7426+
74287427 # Create indices and updates for ScatterND to set position dim+1 to 'repeats'
74297428 update_indices = op .Reshape (op .Constant (value_ints = [dim + 1 ]), [1 , 1 ])
74307429 update_values = op .Constant (value_ints = [repeats ])
7431-
7430+
74327431 tile_pattern = op .ScatterND (ones_pattern , update_indices , update_values )
7433-
7432+
74347433 # Tile the expanded tensor
74357434 tiled = op .Tile (self_expanded , tile_pattern )
7436-
7435+
74377436 # Reshape to merge the two dimensions
74387437 # Calculate new shape: original shape with target dimension multiplied by repeats
74397438 target_dim_size = op .Gather (original_shape , op .Constant (value_ints = [dim ]))
74407439 new_target_size = op .Mul (target_dim_size , op .Constant (value_ints = [repeats ]))
7441-
7440+
74427441 # Create new shape by updating the target dimension
74437442 update_shape_indices = op .Reshape (op .Constant (value_ints = [dim ]), [1 , 1 ])
7444- new_shape = op .ScatterND (original_shape , update_shape_indices ,
7445- op .Reshape (new_target_size , [1 ]))
7446-
7443+ new_shape = op .ScatterND (
7444+ original_shape , update_shape_indices , op .Reshape (new_target_size , [1 ])
7445+ )
7446+
74477447 result = op .Reshape (tiled , new_shape )
74487448 return result
74497449
0 commit comments