@@ -4259,15 +4259,48 @@ def aten_index_put(
42594259 See implementation of `torch.onnx.symbolic_opset11.index_put
42604260 <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
42614261 """
4262+ # Pad indices with None so It has the same rank as self
4263+ self_rank = len (self .shape )
4264+ if len (indices ) < self_rank :
4265+ indices = list (indices ) + [None ] * (self_rank - len (indices ))
42624266
4263- # TODO(justinchuby): Handle when indicies has more than one element
4264- index = indices [0 ]
4265- new_index = op .Unsqueeze (index , [- 1 ])
4267+ values_shape = values .shape .numpy ()
4268+
4269+ index_vectors = []
4270+ for i , index in enumerate (indices ):
4271+ if index is None :
4272+ # For a full slice, create a range.
4273+ index_vector = op .Range (start = 0 , limit = values_shape [i ], delta = 1 )
4274+ else :
4275+ index_vector = index
4276+
4277+ # Shape vector with 1s, except at axis i.
4278+ shape_vector = [1 ] * self_rank
4279+ shape_vector [i ] = values_shape [i ]
4280+
4281+ # Reshape index_vector so that only the i-th dimension matches values_shape[i]
4282+ reshaped_index_vector = op .Reshape (index_vector , shape_vector )
4283+
4284+ # Expand reshaped_index_vector to match the full shape of values
4285+ expanded_index_vector = op .Expand (reshaped_index_vector , values_shape )
4286+
4287+ # Flatten into a 1D vector
4288+ column_index_vector = op .Reshape (expanded_index_vector , [- 1 ])
4289+
4290+ # Convert into a column vector to prepare for concatenation
4291+ column_index_vector = op .Unsqueeze (column_index_vector , axes = [1 ])
4292+ index_vectors .append (column_index_vector )
4293+
4294+ # Contains all indices to be upadated
4295+ new_index = op .Concat (* index_vectors , axis = 1 )
4296+
4297+ # Flatten values to match the indices
4298+ flat_values = op .Reshape (values , [- 1 ])
42664299
42674300 if accumulate :
4268- result = op .ScatterND (self , new_index , values , reduction = "add" )
4301+ result = op .ScatterND (self , new_index , flat_values , reduction = "add" )
42694302 else :
4270- result = op .ScatterND (self , new_index , values )
4303+ result = op .ScatterND (self , new_index , flat_values )
42714304
42724305 return result
42734306
0 commit comments