@@ -4541,80 +4541,135 @@ def aten_index_put(
     See implementation of `torch.onnx.symbolic_opset11.index_put
     <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
     """
-
-    def _make_reshape_list_broadcastable(reshape_list, values_shape):
-        # Remove ones until the rank of reshape_list matches values_shape.
-        while len(reshape_list) > len(values_shape) and 1 in reshape_list:
-            reshape_list.remove(1)
-
-        # Now ensure each dimension is broadcastable:
-        # this is mandatory when mixing basic and advanced indexing.
-        # Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]), values(2, 3):
-        # the reshape list should be: [[2, 1], [1, 3], [2, 1]]
-        for i, r in enumerate(reshape_list):
-            if r not in (1, values_shape[i]):
-                value_index = values_shape.index(r)
-                # Swap elements.
-                # For the example above the current reshape list is [1, 2] for the last dim;
-                # to make it broadcastable, we swap the elements.
-                reshape_list[value_index], reshape_list[i] = r, 1
-
-        return reshape_list
-
-    # Ensure the number of indices matches the tensor rank.
+    # Ensure the number of indices matches the tensor rank by appending trailing Nones.
     self_rank = len(self.shape)
     if len(indices) < self_rank:
         indices = list(indices) + [None] * (self_rank - len(indices))
 
-    # Get values shape
-    values_shape = tuple(values.shape)
+    # The behavior of the op depends on whether there are advanced indices
+    # (i.e., non-scalar tensors) and whether these advanced indices are contiguous.
+
+    # Identify advanced indices.
+    def is_advanced_index(index):
+        # Note: in this function, the index is assumed to be either None or an int64 Tensor.
+        return index is not None
+
+    advanced_indices: list[int] = []
+    none_indices: list[int] = []
+    num_advanced_indices = 0
+    num_none_indices = 0
+
+    for i, index in enumerate(indices):
+        if is_advanced_index(index):
+            advanced_indices.append(i)
+            num_advanced_indices += 1
+        elif index is None:
+            none_indices.append(i)
+            num_none_indices += 1
+        else:
+            raise ValueError(f"Unhandled index at position {i}: {index}")
 
-    index_vectors = []
-    for i in range(self_rank):
-        if indices[i] is None:
-            # For a full slice along dim i, create a range index [0, self.shape[i]).
-            idx = op.Range(0, self.shape[i], 1)
-            reshape_update = self.shape[i]
+    self_shape = op.Shape(self)
+    if num_advanced_indices == 0:
+        return op.Expand(values, self_shape)
+
+    # More than one advanced index may require broadcasting of the index values.
+    if num_advanced_indices > 1:
+        # Check for the special case where all advanced indices have the same shape.
+        # We need to ensure none of the shapes has None as a dimension, which
+        # would invalidate the equality-based check.
+        first_shape = indices[advanced_indices[0]].shape
+
+        def same_shape(other_shape: ir.Shape) -> bool:
+            return (not any(d is None for d in other_shape)) and other_shape == first_shape
+
+        all_same_shape = all(same_shape(indices[i].shape) for i in advanced_indices)
+        if not all_same_shape:
+            # Broadcast advanced indices to a common shape.
+            advanced_index_rank = max(len(indices[i].shape) for i in advanced_indices)
+            shapes = []
+            for i in advanced_indices:
+                index = indices[i]
+                index_rank = len(index.shape)
+                index_shape = op.Shape(index)
+                if index_rank < advanced_index_rank:
+                    padding = op.Constant(
+                        value_ints=[1 for _ in range(advanced_index_rank - index_rank)]
+                    )
+                    index_shape = op.Concat(padding, index_shape, axis=0)
+                shapes.append(index_shape)
+            advanced_indices_shape = op.Max(*shapes)
+            indices = [
+                op.Expand(index, advanced_indices_shape) if is_advanced_index(index) else index
+                for index in indices
+            ]
         else:
-            idx = indices[i]
-            reshape_update = math.prod(idx.shape)
-            # When the index is more than 1D, flatten it and also the values shape.
-            # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3):
-            # indices -> (2*4,) and values shape (2*4, 3).
-            if len(idx.shape) > 1:
-                values_shape = (reshape_update, *values_shape[len(idx.shape) :])
-
-            # Flatten index (always working with a 1D index in each dim).
-            idx = op.Reshape(idx, [-1])
-
-        # Create a reshape pattern: one value per index dimension,
-        # with the current dimension set to the update size.
-        reshape_list = [1] * len(indices)
-        reshape_list[i] = reshape_update
-
-        # Adjust the reshape list to match the values shape.
-        reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape)
-
-        # Reshape and expand the index.
-        idx = op.Reshape(idx, reshape_list, allowzero=True)
-        idx = op.Expand(idx, values_shape)
-
-        # Flatten the index to 1D and unsqueeze to form a column vector.
-        idx = op.Reshape(idx, [-1])
-        idx = op.Unsqueeze(idx, axes=[1])
-        index_vectors.append(idx)
-
-    # Concatenate the index vectors along axis=1 to form the final indices.
-    new_index = op.Concat(*index_vectors, axis=1)
-
-    # Flatten values to match the indices.
-    flat_values = op.Reshape(values, [-1])
-
-    if accumulate:
-        result = op.ScatterND(self, new_index, flat_values, reduction="add")
+            advanced_indices_shape = op.Shape(indices[advanced_indices[0]])
+            advanced_index_rank = len(indices[advanced_indices[0]].shape)
     else:
-        result = op.ScatterND(self, new_index, flat_values)
+        advanced_indices_shape = op.Shape(indices[advanced_indices[0]])
+        advanced_index_rank = len(indices[advanced_indices[0]].shape)
+
+    # ONNX ScatterND supports only the case where all advanced indices appear first,
+    # followed by None indices. So we need to transpose self and values so that the
+    # advanced indices appear first, and then transpose the result back to the
+    # original order at the end.
+
+    none_indices_constant = op.Constant(value_ints=none_indices)
+    none_indices_shape = op.Gather(self_shape, none_indices_constant, axis=0)
+    target_shape = op.Concat(advanced_indices_shape, none_indices_shape, axis=0)
+    target_rank = advanced_index_rank + num_none_indices
+
+    # Generate the indices tensor required by ONNX ScatterND by unsqueezing an extra
+    # dimension and concatenating all advanced indices along this new dimension.
+    minus_one = op.Constant(value_ints=[-1])
+    advanced_index_values = [op.Unsqueeze(indices[i], minus_one) for i in advanced_indices]
+    onnx_index = op.Concat(*advanced_index_values, axis=-1)
+
+    # Check whether the advanced indices are contiguous:
+    contiguous = True
+    if advanced_indices:
+        if advanced_indices[-1] - advanced_indices[0] + 1 != len(advanced_indices):
+            contiguous = False
+
+    # Bring the advanced indices to the front:
+    perm = advanced_indices + none_indices
+    transposed = op.Transpose(self, perm=perm)
+
+    # Expand values to match the target shape. First, transpose values if necessary
+    # to match the advanced-indices-first order.
+    if contiguous:
+        # values may need to be transposed before expanding to the target shape
+        num_padded_dims = target_rank - len(values.shape)
+        if num_padded_dims > 0:
+            unsqueezed_dims = op.Constant(value_ints=list(range(num_padded_dims)))
+            values = op.Unsqueeze(values, unsqueezed_dims)
+        initial_none_index_positions = list(range(advanced_indices[0]))
+        advanced_index_replacement_positions = list(
+            range(advanced_indices[0], advanced_indices[0] + advanced_index_rank)
+        )
+        final_none_index_positions = list(
+            range(advanced_indices[0] + advanced_index_rank, target_rank)
+        )
+        values_perm = (
+            advanced_index_replacement_positions
+            + initial_none_index_positions
+            + final_none_index_positions
+        )
+        values = op.Transpose(values, perm=values_perm)
+
+    expanded_values = op.Expand(values, target_shape)
+
+    updated = op.ScatterND(
+        transposed, onnx_index, expanded_values, reduction="add" if accumulate else None
+    )
+
+    # Inverse transpose to restore the original dimension order:
 
+    inverse_perm = [0] * self_rank
+    for i, p in enumerate(perm):
+        inverse_perm[p] = i
+    result = op.Transpose(updated, perm=inverse_perm)
     return result
 
 
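For reviewers, here is a minimal eager-mode sketch of the two cases the contiguity check above distinguishes. It assumes a recent PyTorch where `Tensor.index_put_` accepts `None` entries in the indices tuple; the shapes are illustrative (reusing the example from the removed helper's comments), not taken from the test suite.

```python
import torch

data = torch.zeros(10, 3, 4)
rows = torch.tensor([0, 1])
cols = torch.tensor([0, 1])

# Non-contiguous advanced indices (positions 0 and 2, with a full slice between
# them): PyTorch moves the broadcast index dimension to the front of the result,
# so the update values have shape (2, 3).
data.index_put_((rows, None, cols), torch.ones(2, 3))

# Contiguous advanced indices (positions 1 and 2): the index dimension stays at
# the position of the first advanced index, so the update values have shape (10, 2).
data.index_put_((None, rows, cols), torch.ones(10, 2))

# accumulate=True sums duplicate index hits, which is what ScatterND with
# reduction="add" reproduces in the exported graph.
data.index_put_((rows, None, cols), torch.ones(2, 3), accumulate=True)
```

In the non-contiguous case, the transposed layout (advanced indices first) already matches the layout of `values`, which is why the `values_perm` transpose is applied only on the contiguous path.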
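And a small self-check of the closing inverse-permutation loop, in plain NumPy. The `perm` value here is hypothetical (a rank-4 input with advanced indices at dims 0 and 2), chosen only to show that transposing by `perm` and then by `inverse_perm` restores the original axis order:

```python
import numpy as np

perm = [0, 2, 1, 3]  # hypothetical: advanced_indices = [0, 2], none_indices = [1, 3]

# Same computation as the loop at the end of aten_index_put:
inverse_perm = [0] * len(perm)
for i, p in enumerate(perm):
    inverse_perm[p] = i

x = np.random.rand(2, 3, 4, 5)
assert np.array_equal(x.transpose(perm).transpose(inverse_perm), x)
```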