Skip to content

Commit a22cbfb

Browse files
Copilotxadupre
andcommitted
Revert self_tensor to original clean implementation, keep self_int optimization
Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
1 parent d291ae3 commit a22cbfb

File tree

1 file changed

+29
-28
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+29
-28
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7331,58 +7331,59 @@ def aten_repeat_interleave_self_tensor(
73317331
# Convert repeats to int64 for ONNX compatibility
73327332
repeats_int64 = op.Cast(repeats, to=INT64.dtype)
73337333

7334-
# Use an approach similar to self_int but adapted for variable repeats
7335-
# The key optimization: avoid creating large intermediate index tensors
7336-
7337-
# Get cumulative sum to determine output positions
7334+
# Get cumulative sum of repeats to find the boundaries
73387335
cumsum = op.CumSum(repeats_int64, axis=0)
73397336
total_size = op.Gather(cumsum, op.Constant(value_ints=[-1]), axis=0)
73407337

7341-
# Create output indices
7338+
# Create output tensor indices
73427339
output_range = op.Range(
73437340
op.Constant(value_ints=[0]), total_size, op.Constant(value_ints=[1])
73447341
)
73457342

7346-
# More efficient searchsorted: find input index for each output position
7347-
# Broadcast to find positions where output_idx < cumsum_val
7348-
cumsum_expanded = op.Unsqueeze(cumsum, [0]) # [1, n_elements]
7349-
output_expanded = op.Unsqueeze(output_range, [1]) # [total_size, 1]
7343+
# Find which original index each output position corresponds to
7344+
cumsum_expanded = op.Unsqueeze(cumsum, [0]) # Shape: [1, len(repeats)]
7345+
output_range_expanded = op.Unsqueeze(output_range, [1]) # Shape: [total_size, 1]
7346+
7347+
# Find positions where output_range < cumsum
7348+
mask = op.Less(
7349+
output_range_expanded, cumsum_expanded
7350+
) # Shape: [total_size, len(repeats)]
73507351

7351-
# Find first position where output_idx < cumsum_val
7352-
mask = op.Less(output_expanded, cumsum_expanded) # [total_size, n_elements]
7353-
input_indices = op.ArgMax(op.Cast(mask, to=INT64.dtype), axis=1, keepdims=False)
7352+
# For each row, find the first True position
7353+
indices = op.ArgMax(op.Cast(mask, to=INT64.dtype), axis=1, keepdims=False)
73547354

7355-
# Gather the actual values
7356-
result = op.Gather(self_flat, input_indices, axis=0)
7355+
# Gather elements from the flattened tensor
7356+
result = op.Gather(self_flat, indices, axis=0)
73577357
return result
73587358

73597359
else:
7360-
# Repeat along specific dimension - use approach similar to optimized self_int
7360+
# Repeat along specific dimension
73617361
# Convert repeats to int64 for ONNX compatibility
73627362
repeats_int64 = op.Cast(repeats, to=INT64.dtype)
73637363

7364-
# Use a more efficient approach similar to self_int optimization
7365-
# The challenge is that we have variable repeat counts per slice
7366-
7367-
# Get cumulative sum to find boundaries (this part is necessary for variable repeats)
7364+
# Get cumulative sum of repeats to find the boundaries
73687365
cumsum = op.CumSum(repeats_int64, axis=0)
73697366
total_size = op.Gather(cumsum, op.Constant(value_ints=[-1]), axis=0)
73707367

7371-
# Create output indices for the dimension
7368+
# Create output tensor indices for the specified dimension
73727369
output_range = op.Range(
73737370
op.Constant(value_ints=[0]), total_size, op.Constant(value_ints=[1])
73747371
)
73757372

7376-
# Efficient mapping from output positions to input indices
7377-
cumsum_expanded = op.Unsqueeze(cumsum, [0]) # [1, n_slices]
7378-
output_expanded = op.Unsqueeze(output_range, [1]) # [total_size, 1]
7373+
# Find which original index each output position corresponds to
7374+
cumsum_expanded = op.Unsqueeze(cumsum, [0]) # Shape: [1, len(repeats)]
7375+
output_range_expanded = op.Unsqueeze(output_range, [1]) # Shape: [total_size, 1]
7376+
7377+
# Find positions where output_range < cumsum
7378+
mask = op.Less(
7379+
output_range_expanded, cumsum_expanded
7380+
) # Shape: [total_size, len(repeats)]
73797381

7380-
# Find input slice index for each output position
7381-
mask = op.Less(output_expanded, cumsum_expanded) # [total_size, n_slices]
7382-
input_indices = op.ArgMax(op.Cast(mask, to=INT64.dtype), axis=1, keepdims=False)
7382+
# For each row, find the first True position
7383+
indices = op.ArgMax(op.Cast(mask, to=INT64.dtype), axis=1, keepdims=False)
73837384

7384-
# Gather slices along the specified dimension
7385-
result = op.Gather(self, input_indices, axis=dim)
7385+
# Gather elements along the specified dimension
7386+
result = op.Gather(self, indices, axis=dim)
73867387
return result
73877388

73887389

0 commit comments

Comments
 (0)