
Commit a3883a6

Update aten_index_put implementation (#2712)
1 parent 5583f96 commit a3883a6
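For context, `aten::index_put` writes `values` into `self` at positions given by a list of optional index tensors: a tensor entry is an advanced index, and a `None` entry is a basic index (full slice). Whether the advanced indices are contiguous changes where the broadcast index dimensions land in the result, which is the distinction the updated implementation handles. A minimal PyTorch sketch of that behavior (illustration only, not code from this commit):

# Illustration only: the PyTorch indexing semantics that aten_index_put lowers.
import torch

data = torch.zeros(5, 4, 3)
i1 = torch.tensor([0, 2])
i2 = torch.tensor([1, 1])

# Contiguous advanced indices (dims 1 and 2): the broadcast index dimension
# of size 2 replaces the indexed dims in place.
assert data[:, i1, i2].shape == (5, 2)
data[:, i1, i2] = torch.ones(5, 2)  # like aten::index_put(data, [None, i1, i2], ...)

# Non-contiguous advanced indices (dims 0 and 2, full slice in dim 1):
# the broadcast index dimension moves to the front.
assert data[i1, :, i2].shape == (2, 4)
data[i1, :, i2] = torch.ones(2, 4)  # like aten::index_put(data, [i1, None, i2], ...)

Passing `accumulate=True` to `index_put` adds the values at the selected positions instead of overwriting them.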

File tree: 2 files changed (+368, -68 lines)

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 121 additions & 66 deletions
@@ -4541,80 +4541,135 @@ def aten_index_put(
     See implementation of `torch.onnx.symbolic_opset11.index_put
     <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
     """
-
-    def _make_reshape_list_broadcastable(reshape_list, values_shape):
-        # Remove ones until the rank of reshape_list matches values_shape.
-        while len(reshape_list) > len(values_shape) and 1 in reshape_list:
-            reshape_list.remove(1)
-
-        # Now ensure each dimension is broadcastable:
-        # This is mandatory when mixing basic and advanced indexing
-        # Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3)
-        # the reshape list should be : [[2, 1], [1, 3], [2, 1]]
-        for i, r in enumerate(reshape_list):
-            if r not in (1, values_shape[i]):
-                value_index = values_shape.index(r)
-                # Swap elements
-                # For the example above the current reshape list is [1, 2] for last dim,
-                # to make it broadcastable, we swap the elements
-                reshape_list[value_index], reshape_list[i] = r, 1
-
-        return reshape_list
-
-    # Ensure the number of indices matches the tensor rank.
+    # Ensure the number of indices matches the tensor rank by appending trailing Nones.
     self_rank = len(self.shape)
     if len(indices) < self_rank:
         indices = list(indices) + [None] * (self_rank - len(indices))

-    # Get values shape
-    values_shape = tuple(values.shape)
+    # The behavior of the op is dependent on whether there are advanced indices (i.e., non-scalar tensors)
+    # and whether these advanced indices are contiguous.
+
+    # Identify advanced indices.
+    def is_advanced_index(index):
+        # Note: In this function, the index is assumed to be either None or an int64 Tensor.
+        return index is not None
+
+    advanced_indices: list[int] = []
+    none_indices: list[int] = []
+    num_advanced_indices = 0
+    num_none_indices = 0
+
+    for i, index in enumerate(indices):
+        if is_advanced_index(index):
+            advanced_indices.append(i)
+            num_advanced_indices += 1
+        elif index is None:
+            none_indices.append(i)
+            num_none_indices += 1
+        else:
+            raise ValueError(f"Unhandled index at position {i}: {index}")

-    index_vectors = []
-    for i in range(self_rank):
-        if indices[i] is None:
-            # For a full slice along dim i, create a range index [0, self.shape[i]).
-            idx = op.Range(0, self.shape[i], 1)
-            reshape_update = self.shape[i]
+    self_shape = op.Shape(self)
+    if num_advanced_indices == 0:
+        return op.Expand(values, self_shape)
+
+    # More than one advanced index may require broadcasting of index values
+    if num_advanced_indices > 1:
+        # Check for special case where all advanced indices have same shape.
+        # But need to ensure none of the shapes have None as a dimension, which
+        # will invalidate equality-based check.
+        first_shape = indices[advanced_indices[0]].shape
+
+        def same_shape(other_shape: ir.Shape) -> bool:
+            return (not any(d is None for d in other_shape)) and other_shape == first_shape
+
+        all_same_shape = all(same_shape(indices[i].shape) for i in advanced_indices)
+        if not all_same_shape:
+            # Broadcast advanced indices to a common shape.
+            advanced_index_rank = max(len(indices[i].shape) for i in advanced_indices)
+            shapes = []
+            for i in advanced_indices:
+                index = indices[i]
+                index_rank = len(index.shape)
+                index_shape = op.Shape(index)
+                if index_rank < advanced_index_rank:
+                    padding = op.Constant(
+                        value_ints=[1 for _ in range(advanced_index_rank - index_rank)]
+                    )
+                    index_shape = op.Concat(padding, index_shape, axis=0)
+                shapes.append(index_shape)
+            advanced_indices_shape = op.Max(*shapes)
+            indices = [
+                op.Expand(index, advanced_indices_shape) if is_advanced_index(index) else index
+                for index in indices
+            ]
         else:
-            idx = indices[i]
-            reshape_update = math.prod(idx.shape)
-            # when Index is more than 1D, flatten it and also the values shape
-            # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3)
-            # Indices -> (2*4,) and values shape (2*4, 3)
-            if len(idx.shape) > 1:
-                values_shape = (reshape_update, *values_shape[len(idx.shape) :])
-
-            # Flatten index (always working with 1D index in each dim)
-            idx = op.Reshape(idx, [-1])
-
-            # Create a reshape pattern: one value per index dimension,
-            # with the current dimension set to the update size.
-            reshape_list = [1] * len(indices)
-            reshape_list[i] = reshape_update
-
-            # Adjust the reshape list to match the values shape.
-            reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape)
-
-            # Reshape and expand the index.
-            idx = op.Reshape(idx, reshape_list, allowzero=True)
-            idx = op.Expand(idx, values_shape)
-
-            # Flatten the index to 1D and unsqueeze to form a column vector.
-            idx = op.Reshape(idx, [-1])
-            idx = op.Unsqueeze(idx, axes=[1])
-            index_vectors.append(idx)
-
-    # Concatenate the index vectors along axis=1 to form the final indices.
-    new_index = op.Concat(*index_vectors, axis=1)
-
-    # Flatten values to match the indices
-    flat_values = op.Reshape(values, [-1])
-
-    if accumulate:
-        result = op.ScatterND(self, new_index, flat_values, reduction="add")
+            advanced_indices_shape = op.Shape(indices[advanced_indices[0]])
+            advanced_index_rank = len(indices[advanced_indices[0]].shape)
     else:
-        result = op.ScatterND(self, new_index, flat_values)
+        advanced_indices_shape = op.Shape(indices[advanced_indices[0]])
+        advanced_index_rank = len(indices[advanced_indices[0]].shape)
+
+    # ONNX ScatterND supports only the case where all advanced indices appear first,
+    # followed by None indices. So, we need to transpose self and values so that the
+    # advanced indices appear first, and then transpose the result back to original
+    # order at the end.
+
+    none_indices_constant = op.Constant(value_ints=none_indices)
+    none_indices_shape = op.Gather(self_shape, none_indices_constant, axis=0)
+    target_shape = op.Concat(advanced_indices_shape, none_indices_shape, axis=0)
+    target_rank = advanced_index_rank + num_none_indices
+
+    # Generate indices tensor required by ONNX ScatterND by unsqueezing an extra dimension and
+    # concatenating all advanced indices along this new dimension.
+    minus_one = op.Constant(value_ints=[-1])
+    advanced_index_values = [op.Unsqueeze(indices[i], minus_one) for i in advanced_indices]
+    onnx_index = op.Concat(*advanced_index_values, axis=-1)
+
+    # Check if advanced indices are contiguous:
+    contiguous = True
+    if advanced_indices:
+        if advanced_indices[-1] - advanced_indices[0] + 1 != len(advanced_indices):
+            contiguous = False
+
+    # Bring advanced indices to front:
+    perm = advanced_indices + none_indices
+    transposed = op.Transpose(self, perm=perm)
+
+    # Expand values to match target shape:
+    # First, transpose values if necessary to match advanced indices order!
+    if contiguous:
+        # values may need to be transposed before expanding to target shape
+        num_padded_dims = target_rank - len(values.shape)
+        if num_padded_dims > 0:
+            unsqueezed_dims = op.Constant(value_ints=list(range(num_padded_dims)))
+            values = op.Unsqueeze(values, unsqueezed_dims)
+        initial_none_index_positions = list(range(advanced_indices[0]))
+        advanced_index_replacement_positions = list(
+            range(advanced_indices[0], advanced_indices[0] + advanced_index_rank)
+        )
+        final_none_index_positions = list(
+            range(advanced_indices[0] + advanced_index_rank, target_rank)
+        )
+        values_perm = (
+            advanced_index_replacement_positions
+            + initial_none_index_positions
+            + final_none_index_positions
+        )
+        values = op.Transpose(values, perm=values_perm)
+
+    expanded_values = op.Expand(values, target_shape)
+
+    updated = op.ScatterND(
+        transposed, onnx_index, expanded_values, reduction="add" if accumulate else None
+    )
+
+    # Inverse transpose to restore original dimension order:

+    inverse_perm = [0] * self_rank
+    for i, p in enumerate(perm):
+        inverse_perm[p] = i
+    result = op.Transpose(updated, perm=inverse_perm)
     return result
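As a reading aid: ONNX `ScatterND` takes an `indices` tensor whose last axis addresses the leading dimensions of `data`, with `updates` shaped as `indices.shape[:-1] + data.shape[k:]`. That is why the new code transposes the advanced-index dimensions of `self` to the front, scatters, and then applies the inverse permutation. A rough NumPy model of `ScatterND` (illustration only; `scatter_nd_reference` is a hypothetical name, not an onnxscript API), assuming in-bounds integer indices:

# Rough NumPy model of ONNX ScatterND, for reading the new code above.
import numpy as np

def scatter_nd_reference(data, indices, updates, reduction=None):
    out = np.copy(data)
    k = indices.shape[-1]  # each index row addresses the first k dims of data
    index_rows = indices.reshape(-1, k)
    update_slices = updates.reshape(-1, *data.shape[k:])
    for row, update in zip(index_rows, update_slices):
        if reduction == "add":  # accumulate=True maps to reduction="add"
            out[tuple(row)] += update
        else:  # default: plain overwrite
            out[tuple(row)] = update
    return out

Under this model, the final `inverse_perm[p] = i` loop simply restores the original dimension order of `self` after the scatter is performed on the transposed tensor.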