@@ -7299,18 +7299,21 @@ def aten_repeat_interleave(
72997299 )
73007300
73017301 # Find which original index each output position corresponds to
7302- # We need to find the first cumsum position > each output position
7303- # This is equivalent to a searchsorted operation
7302+ # Use the same approach as in self_tensor version
7303+ num_elements = op . Size ( repeats_int64 )
73047304
7305- # Expand dimensions for broadcasting
7306- cumsum_expanded = op .Unsqueeze (cumsum , [0 ]) # Shape: [1, len(repeats)]
7307- output_range_expanded = op .Unsqueeze (output_range , [1 ]) # Shape: [total_size, 1]
7305+ cumsum_expanded = op .Unsqueeze (cumsum , [0 ]) # [1, num_elements]
7306+ output_expanded = op .Unsqueeze (output_range , [1 ]) # [total_size, 1]
73087307
7309- # Find positions where output_range < cumsum
7310- mask = op .Less ( output_range_expanded , cumsum_expanded ) # Shape: [total_size, len(repeats) ]
7308+ # Use LessOrEqual to find cumsum <= output_pos
7309+ mask = op .LessOrEqual ( cumsum_expanded , output_expanded ) # [total_size, num_elements ]
73117310
7312- # For each row, find the first True position (argmax will do this since True=1, False=0)
7313- result_indices = op .ArgMax (op .Cast (mask , to = INT64 .dtype ), axis = 1 , keepdims = False )
7311+ # Sum to get the count of cumsum values <= each position
7312+ result_indices = op .ReduceSum (op .Cast (mask , to = INT64 .dtype ), axes = [1 ], keepdims = False )
7313+
7314+ # Clamp to valid range [0, num_elements-1]
7315+ max_index = op .Sub (num_elements , op .Constant (value_ints = [1 ]))
7316+ result_indices = op .Clip (result_indices , op .Constant (value_ints = [0 ]), max_index )
73147317
73157318 return result_indices
73167319
@@ -7325,64 +7328,85 @@ def aten_repeat_interleave_self_tensor(
73257328 """repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor"""
73267329
73277330 if dim is None :
7328- # Flatten the tensor first, then repeat elements
7331+ # Flatten the tensor first
73297332 self_flat = op .Reshape (self , [- 1 ])
73307333
73317334 # Convert repeats to int64 for ONNX compatibility
73327335 repeats_int64 = op .Cast (repeats , to = INT64 .dtype )
73337336
7334- # Get cumulative sum of repeats to find the boundaries
7337+ # Approach: compute, for every output position, the index of the source element
7338+ # it comes from, then gather those elements in one op
7339+
7340+ # Get the length of repeats (number of elements)
7341+ num_elements = op .Size (repeats_int64 )
7342+
7343+ # The result cannot be built by processing elements one at a time,
7344+ # since we can't use loops here, so the mapping is computed with vectorized ops
7345+
7346+ # Alternative: create indices by "unrolling" the repeats
7347+ # Build a tensor where position i contains the element index for output position i
7348+
7349+ # First, get cumulative sum to know boundaries
73357350 cumsum = op .CumSum (repeats_int64 , axis = 0 )
73367351 total_size = op .Gather (cumsum , op .Constant (value_ints = [- 1 ]), axis = 0 )
73377352
7338- # Create output tensor indices
7339- output_range = op .Range (
7353+ # Build the indices tensor from a comparison mask: each index is obtained by
7354+ # counting mask entries per row (ReduceSum) instead of locating the first True (ArgMax)
7355+
7356+ # For each possible output position, compute which input element it corresponds to
7357+ # by comparing against cumulative sums
7358+
7359+ # Create range for all output positions
7360+ output_positions = op .Range (
73407361 op .Constant (value_ints = [0 ]), total_size , op .Constant (value_ints = [1 ])
73417362 )
73427363
7343- # Find which original index each output position corresponds to
7344- cumsum_expanded = op .Unsqueeze (cumsum , [0 ]) # Shape: [1, len(repeats)]
7345- output_range_expanded = op .Unsqueeze (output_range , [1 ]) # Shape: [total_size, 1]
7364+ # For each output position, we need to find which element it belongs to
7365+ # Instead of ArgMax, we can use: sum(cumsum <= output_pos)
7366+ # This counts the elements whose cumulative end boundary is <= output_pos,
7367+ # and that count is the 0-based index of the element that output_pos belongs to
7368+
7369+ # Expand for broadcasting
7370+ cumsum_expanded = op .Unsqueeze (cumsum , [0 ]) # [1, num_elements]
7371+ positions_expanded = op .Unsqueeze (output_positions , [1 ]) # [total_size, 1]
7372+
7373+ # Compare: cumsum <= output_pos (note: LessOrEqual instead of Less)
7374+ mask = op .LessOrEqual (
7375+ cumsum_expanded , positions_expanded
7376+ ) # [total_size, num_elements]
73467377
7347- # Find positions where output_range < cumsum
7348- mask = op .Less (
7349- output_range_expanded , cumsum_expanded
7350- ) # Shape: [total_size, len(repeats)]
7378+ # Sum to get the count of cumsum values <= each position
7379+ indices = op .ReduceSum (op .Cast (mask , to = INT64 .dtype ), axes = [1 ], keepdims = False )
73517380
7352- # For each row, find the first True position
7353- indices = op .ArgMax (op .Cast (mask , to = INT64 .dtype ), axis = 1 , keepdims = False )
7381+ # Clamp to valid range [0, num_elements-1]
7382+ max_index = op .Sub (num_elements , op .Constant (value_ints = [1 ]))
7383+ indices = op .Clip (indices , op .Constant (value_ints = [0 ]), max_index )
73547384
73557385 # Gather elements from the flattened tensor
73567386 result = op .Gather (self_flat , indices , axis = 0 )
73577387 return result
73587388
73597389 else :
7360- # Repeat along specific dimension
7361- # Convert repeats to int64 for ONNX compatibility
7390+ # Repeat along specific dimension using the same approach
73627391 repeats_int64 = op .Cast (repeats , to = INT64 .dtype )
73637392
7364- # Get cumulative sum of repeats to find the boundaries
7393+ num_elements = op . Size ( repeats_int64 )
73657394 cumsum = op .CumSum (repeats_int64 , axis = 0 )
73667395 total_size = op .Gather (cumsum , op .Constant (value_ints = [- 1 ]), axis = 0 )
73677396
7368- # Create output tensor indices for the specified dimension
7369- output_range = op .Range (
7397+ output_positions = op .Range (
73707398 op .Constant (value_ints = [0 ]), total_size , op .Constant (value_ints = [1 ])
73717399 )
73727400
7373- # Find which original index each output position corresponds to
7374- cumsum_expanded = op .Unsqueeze (cumsum , [0 ]) # Shape: [1, len(repeats)]
7375- output_range_expanded = op .Unsqueeze (output_range , [1 ]) # Shape: [total_size, 1]
7401+ cumsum_expanded = op .Unsqueeze (cumsum , [0 ])
7402+ positions_expanded = op .Unsqueeze (output_positions , [1 ])
73767403
7377- # Find positions where output_range < cumsum
7378- mask = op .Less (
7379- output_range_expanded , cumsum_expanded
7380- ) # Shape: [total_size, len(repeats)]
7404+ mask = op .LessOrEqual (cumsum_expanded , positions_expanded )
7405+ indices = op .ReduceSum (op .Cast (mask , to = INT64 .dtype ), axes = [1 ], keepdims = False )
73817406
7382- # For each row, find the first True position
7383- indices = op .ArgMax ( op .Cast ( mask , to = INT64 . dtype ), axis = 1 , keepdims = False )
7407+ max_index = op . Sub ( num_elements , op . Constant ( value_ints = [ 1 ]))
7408+ indices = op .Clip ( indices , op .Constant ( value_ints = [ 0 ] ), max_index )
73847409
7385- # Gather elements along the specified dimension
73867410 result = op .Gather (self , indices , axis = dim )
73877411 return result
73887412
@@ -7408,41 +7432,55 @@ def aten_repeat_interleave_self_int(
74087432 return result
74097433
74107434 else :
7411- # Repeat along specific dimension
7412- # Apply Tile directly to the tensor instead of creating indices (more efficient)
7435+ # Repeat along specific dimension using simpler approach
7436+ # First, get the shape of the input tensor
7437+ original_shape = op .Shape (self )
74137438
7414- # Expand tensor by adding dimension after target dim
7439+ # Use the approach similar to aten_repeat but for a single dimension
7440+ # Add a new dimension after the target dimension
74157441 self_expanded = op .Unsqueeze (self , [dim + 1 ])
74167442
7417- # Get original shape to build tile pattern dynamically
7418- original_shape = op .Shape (self )
7419- num_dims = op .Size (original_shape )
7420-
7421- # Build tile pattern: all 1s except position dim+1 which is 'repeats'
7422- # Use ConstantOfShape to create array of 1s, then update specific position
7423- ones_pattern = op .ConstantOfShape (
7424- op .Add (num_dims , op .Constant (value_ints = [1 ])), # +1 for the new dimension
7443+ # Get the rank and build tile pattern
7444+ rank = op .Size (original_shape )
7445+ ones_before = op .ConstantOfShape (
7446+ op .Reshape (
7447+ op .Add (op .Constant (value_ints = [dim ]), op .Constant (value_ints = [1 ])), [1 ]
7448+ ),
7449+ op .Constant (value_ints = [1 ]),
7450+ )
7451+ repeat_val = op .Constant (value_ints = [repeats ])
7452+ ones_after = op .ConstantOfShape (
7453+ op .Reshape (
7454+ op .Sub (
7455+ rank , op .Add (op .Constant (value_ints = [dim ]), op .Constant (value_ints = [1 ]))
7456+ ),
7457+ [1 ],
7458+ ),
74257459 op .Constant (value_ints = [1 ]),
74267460 )
74277461
7428- # Create indices and updates for ScatterND to set position dim+1 to 'repeats'
7429- update_indices = op .Reshape (op .Constant (value_ints = [dim + 1 ]), [1 , 1 ])
7430- update_values = op .Constant (value_ints = [repeats ])
7431-
7432- tile_pattern = op .ScatterND (ones_pattern , update_indices , update_values )
7462+ # Concatenate to build tile pattern: [1, 1, ..., 1, repeats, 1, ..., 1]
7463+ tile_pattern = op .Concat (ones_before , repeat_val , ones_after , axis = 0 )
74337464
74347465 # Tile the expanded tensor
74357466 tiled = op .Tile (self_expanded , tile_pattern )
74367467
7437- # Reshape to merge the two dimensions
7438- # Calculate new shape: original shape with target dimension multiplied by repeats
7468+ # Reshape to merge the repeated dimension
7469+ # Calculate new shape
74397470 target_dim_size = op .Gather (original_shape , op .Constant (value_ints = [dim ]))
74407471 new_target_size = op .Mul (target_dim_size , op .Constant (value_ints = [repeats ]))
74417472
7442- # Create new shape by updating the target dimension
7443- update_shape_indices = op .Reshape (op .Constant (value_ints = [dim ]), [1 , 1 ])
7444- new_shape = op .ScatterND (
7445- original_shape , update_shape_indices , op .Reshape (new_target_size , [1 ])
7473+ # Build new shape by concatenating parts
7474+ shape_before = op .Slice (
7475+ original_shape , op .Constant (value_ints = [0 ]), op .Constant (value_ints = [dim ])
7476+ )
7477+ shape_after = op .Slice (
7478+ original_shape ,
7479+ op .Add (op .Constant (value_ints = [dim ]), op .Constant (value_ints = [1 ])),
7480+ op .Constant (value_ints = [2147483647 ]),
7481+ )
7482+ new_shape = op .Concat (
7483+ shape_before , op .Reshape (new_target_size , [1 ]), shape_after , axis = 0
74467484 )
74477485
74487486 result = op .Reshape (tiled , new_shape )
0 commit comments