Skip to content

Commit 5409040

Browse files
Copilot authored and xadupre committed
Remove ArgMax and ScatterND operations from repeat_interleave implementations
Co-authored-by: xadupre <[email protected]>
1 parent a22cbfb commit 5409040

File tree

1 file changed

+97
-59
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+97
-59
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 97 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -7299,18 +7299,21 @@ def aten_repeat_interleave(
72997299
)
73007300

73017301
# Find which original index each output position corresponds to
7302-
# We need to find the first cumsum position > each output position
7303-
# This is equivalent to a searchsorted operation
7302+
# Use the same approach as in self_tensor version
7303+
num_elements = op.Size(repeats_int64)
73047304

7305-
# Expand dimensions for broadcasting
7306-
cumsum_expanded = op.Unsqueeze(cumsum, [0]) # Shape: [1, len(repeats)]
7307-
output_range_expanded = op.Unsqueeze(output_range, [1]) # Shape: [total_size, 1]
7305+
cumsum_expanded = op.Unsqueeze(cumsum, [0]) # [1, num_elements]
7306+
output_expanded = op.Unsqueeze(output_range, [1]) # [total_size, 1]
73087307

7309-
# Find positions where output_range < cumsum
7310-
mask = op.Less(output_range_expanded, cumsum_expanded) # Shape: [total_size, len(repeats)]
7308+
# Use LessOrEqual to find cumsum <= output_pos
7309+
mask = op.LessOrEqual(cumsum_expanded, output_expanded) # [total_size, num_elements]
73117310

7312-
# For each row, find the first True position (argmax will do this since True=1, False=0)
7313-
result_indices = op.ArgMax(op.Cast(mask, to=INT64.dtype), axis=1, keepdims=False)
7311+
# Sum to get the count of cumsum values <= each position
7312+
result_indices = op.ReduceSum(op.Cast(mask, to=INT64.dtype), axes=[1], keepdims=False)
7313+
7314+
# Clamp to valid range [0, num_elements-1]
7315+
max_index = op.Sub(num_elements, op.Constant(value_ints=[1]))
7316+
result_indices = op.Clip(result_indices, op.Constant(value_ints=[0]), max_index)
73147317

73157318
return result_indices
73167319

@@ -7325,64 +7328,85 @@ def aten_repeat_interleave_self_tensor(
73257328
"""repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor"""
73267329

73277330
if dim is None:
7328-
# Flatten the tensor first, then repeat elements
7331+
# Flatten the tensor first
73297332
self_flat = op.Reshape(self, [-1])
73307333

73317334
# Convert repeats to int64 for ONNX compatibility
73327335
repeats_int64 = op.Cast(repeats, to=INT64.dtype)
73337336

7334-
# Get cumulative sum of repeats to find the boundaries
7337+
# Create a simple approach: for each element, tile it according to its repeat count
7338+
# Then concatenate all results
7339+
7340+
# Get the length of repeats (number of elements)
7341+
num_elements = op.Size(repeats_int64)
7342+
7343+
# We'll build the result by processing each element
7344+
# Since we can't use loops, we need a different approach
7345+
7346+
# Alternative: create indices by "unrolling" the repeats
7347+
# Build a tensor where position i contains the element index for output position i
7348+
7349+
# First, get cumulative sum to know boundaries
73357350
cumsum = op.CumSum(repeats_int64, axis=0)
73367351
total_size = op.Gather(cumsum, op.Constant(value_ints=[-1]), axis=0)
73377352

7338-
# Create output tensor indices
7339-
output_range = op.Range(
7353+
# Create the indices tensor directly using a different algorithm
7354+
# We'll create a "mask" approach but compute indices differently
7355+
7356+
# For each possible output position, compute which input element it corresponds to
7357+
# by comparing against cumulative sums
7358+
7359+
# Create range for all output positions
7360+
output_positions = op.Range(
73407361
op.Constant(value_ints=[0]), total_size, op.Constant(value_ints=[1])
73417362
)
73427363

7343-
# Find which original index each output position corresponds to
7344-
cumsum_expanded = op.Unsqueeze(cumsum, [0]) # Shape: [1, len(repeats)]
7345-
output_range_expanded = op.Unsqueeze(output_range, [1]) # Shape: [total_size, 1]
7364+
# For each output position, we need to find which element it belongs to
7365+
# Instead of ArgMax, we can use: sum(cumsum <= output_pos)
7366+
# This gives us the number of elements whose cumsum is <= output_pos
7367+
# Which means output_pos belongs to the next element
7368+
7369+
# Expand for broadcasting
7370+
cumsum_expanded = op.Unsqueeze(cumsum, [0]) # [1, num_elements]
7371+
positions_expanded = op.Unsqueeze(output_positions, [1]) # [total_size, 1]
7372+
7373+
# Compare: cumsum <= output_pos (note: LessOrEqual instead of Less)
7374+
mask = op.LessOrEqual(
7375+
cumsum_expanded, positions_expanded
7376+
) # [total_size, num_elements]
73467377

7347-
# Find positions where output_range < cumsum
7348-
mask = op.Less(
7349-
output_range_expanded, cumsum_expanded
7350-
) # Shape: [total_size, len(repeats)]
7378+
# Sum to get the count of cumsum values <= each position
7379+
indices = op.ReduceSum(op.Cast(mask, to=INT64.dtype), axes=[1], keepdims=False)
73517380

7352-
# For each row, find the first True position
7353-
indices = op.ArgMax(op.Cast(mask, to=INT64.dtype), axis=1, keepdims=False)
7381+
# Clamp to valid range [0, num_elements-1]
7382+
max_index = op.Sub(num_elements, op.Constant(value_ints=[1]))
7383+
indices = op.Clip(indices, op.Constant(value_ints=[0]), max_index)
73547384

73557385
# Gather elements from the flattened tensor
73567386
result = op.Gather(self_flat, indices, axis=0)
73577387
return result
73587388

73597389
else:
7360-
# Repeat along specific dimension
7361-
# Convert repeats to int64 for ONNX compatibility
7390+
# Repeat along specific dimension using the same approach
73627391
repeats_int64 = op.Cast(repeats, to=INT64.dtype)
73637392

7364-
# Get cumulative sum of repeats to find the boundaries
7393+
num_elements = op.Size(repeats_int64)
73657394
cumsum = op.CumSum(repeats_int64, axis=0)
73667395
total_size = op.Gather(cumsum, op.Constant(value_ints=[-1]), axis=0)
73677396

7368-
# Create output tensor indices for the specified dimension
7369-
output_range = op.Range(
7397+
output_positions = op.Range(
73707398
op.Constant(value_ints=[0]), total_size, op.Constant(value_ints=[1])
73717399
)
73727400

7373-
# Find which original index each output position corresponds to
7374-
cumsum_expanded = op.Unsqueeze(cumsum, [0]) # Shape: [1, len(repeats)]
7375-
output_range_expanded = op.Unsqueeze(output_range, [1]) # Shape: [total_size, 1]
7401+
cumsum_expanded = op.Unsqueeze(cumsum, [0])
7402+
positions_expanded = op.Unsqueeze(output_positions, [1])
73767403

7377-
# Find positions where output_range < cumsum
7378-
mask = op.Less(
7379-
output_range_expanded, cumsum_expanded
7380-
) # Shape: [total_size, len(repeats)]
7404+
mask = op.LessOrEqual(cumsum_expanded, positions_expanded)
7405+
indices = op.ReduceSum(op.Cast(mask, to=INT64.dtype), axes=[1], keepdims=False)
73817406

7382-
# For each row, find the first True position
7383-
indices = op.ArgMax(op.Cast(mask, to=INT64.dtype), axis=1, keepdims=False)
7407+
max_index = op.Sub(num_elements, op.Constant(value_ints=[1]))
7408+
indices = op.Clip(indices, op.Constant(value_ints=[0]), max_index)
73847409

7385-
# Gather elements along the specified dimension
73867410
result = op.Gather(self, indices, axis=dim)
73877411
return result
73887412

@@ -7408,41 +7432,55 @@ def aten_repeat_interleave_self_int(
74087432
return result
74097433

74107434
else:
7411-
# Repeat along specific dimension
7412-
# Apply Tile directly to the tensor instead of creating indices (more efficient)
7435+
# Repeat along specific dimension using simpler approach
7436+
# First, get the shape of the input tensor
7437+
original_shape = op.Shape(self)
74137438

7414-
# Expand tensor by adding dimension after target dim
7439+
# Use the approach similar to aten_repeat but for a single dimension
7440+
# Add a new dimension after the target dimension
74157441
self_expanded = op.Unsqueeze(self, [dim + 1])
74167442

7417-
# Get original shape to build tile pattern dynamically
7418-
original_shape = op.Shape(self)
7419-
num_dims = op.Size(original_shape)
7420-
7421-
# Build tile pattern: all 1s except position dim+1 which is 'repeats'
7422-
# Use ConstantOfShape to create array of 1s, then update specific position
7423-
ones_pattern = op.ConstantOfShape(
7424-
op.Add(num_dims, op.Constant(value_ints=[1])), # +1 for the new dimension
7443+
# Get the rank and build tile pattern
7444+
rank = op.Size(original_shape)
7445+
ones_before = op.ConstantOfShape(
7446+
op.Reshape(
7447+
op.Add(op.Constant(value_ints=[dim]), op.Constant(value_ints=[1])), [1]
7448+
),
7449+
op.Constant(value_ints=[1]),
7450+
)
7451+
repeat_val = op.Constant(value_ints=[repeats])
7452+
ones_after = op.ConstantOfShape(
7453+
op.Reshape(
7454+
op.Sub(
7455+
rank, op.Add(op.Constant(value_ints=[dim]), op.Constant(value_ints=[1]))
7456+
),
7457+
[1],
7458+
),
74257459
op.Constant(value_ints=[1]),
74267460
)
74277461

7428-
# Create indices and updates for ScatterND to set position dim+1 to 'repeats'
7429-
update_indices = op.Reshape(op.Constant(value_ints=[dim + 1]), [1, 1])
7430-
update_values = op.Constant(value_ints=[repeats])
7431-
7432-
tile_pattern = op.ScatterND(ones_pattern, update_indices, update_values)
7462+
# Concatenate to build tile pattern: [1, 1, ..., 1, repeats, 1, ..., 1]
7463+
tile_pattern = op.Concat(ones_before, repeat_val, ones_after, axis=0)
74337464

74347465
# Tile the expanded tensor
74357466
tiled = op.Tile(self_expanded, tile_pattern)
74367467

7437-
# Reshape to merge the two dimensions
7438-
# Calculate new shape: original shape with target dimension multiplied by repeats
7468+
# Reshape to merge the repeated dimension
7469+
# Calculate new shape
74397470
target_dim_size = op.Gather(original_shape, op.Constant(value_ints=[dim]))
74407471
new_target_size = op.Mul(target_dim_size, op.Constant(value_ints=[repeats]))
74417472

7442-
# Create new shape by updating the target dimension
7443-
update_shape_indices = op.Reshape(op.Constant(value_ints=[dim]), [1, 1])
7444-
new_shape = op.ScatterND(
7445-
original_shape, update_shape_indices, op.Reshape(new_target_size, [1])
7473+
# Build new shape by concatenating parts
7474+
shape_before = op.Slice(
7475+
original_shape, op.Constant(value_ints=[0]), op.Constant(value_ints=[dim])
7476+
)
7477+
shape_after = op.Slice(
7478+
original_shape,
7479+
op.Add(op.Constant(value_ints=[dim]), op.Constant(value_ints=[1])),
7480+
op.Constant(value_ints=[2147483647]),
7481+
)
7482+
new_shape = op.Concat(
7483+
shape_before, op.Reshape(new_target_size, [1]), shape_after, axis=0
74467484
)
74477485

74487486
result = op.Reshape(tiled, new_shape)

0 commit comments

Comments (0)