@@ -3151,6 +3151,7 @@ def aten_embedding_bag(
     sparse: bool = False,
     per_sample_weights: Optional[TFloat] = None,
     include_last_offset: bool = False,
+    padding_idx: Optional[int] = None,
 ) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
     """embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)"""
 
@@ -3247,23 +3248,24 @@ def _aten_embedding_bag_onnx(
 
     # Only compute the shape of other 3 outputs, we don't care the value
     if mode == 0:  # sum
-        offset2bag = op.Shape(indices, start=0, end=0)  # Generate empty tensor
+        offset2bag = op.Cast(op.Shape(indices, start=0, end=0), to=INT64.dtype)
         if op.Equal(include_last_offset, True):
-            bag_size = op.Expand(0, op.Shape(offsets))
+            bag_size = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
+            max_indices = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
         else:
-            bag_size = op.Expand(0, op.Shape(offsets) - 1)
-        max_indices = op.Expand(0, op.Shape(bag_size))
+            bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
+            max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
     elif mode == 1:  # mean
-        offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
-        bag_size = op.Expand(0, op.Shape(offsets) - 1)
-        max_indices = op.Expand(0, op.Shape(bag_size))
+        offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
+        bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
+        max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
     else:  # max
-        offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
-        bag_size = op.Expand(0, op.Shape(offsets) - 1)
+        offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
+        bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
         # shape = (bag_size.dim[0], weight.dim[1])
         dim_0 = op.Shape(bag_size, start=0, end=1)
         dim_1 = op.Shape(weight, start=1, end=2)
-        max_indices = op.Expand(0, op.Concat(dim_0, dim_1, axis=0))
+        max_indices = op.Cast(op.Expand(0, op.Concat(dim_0, dim_1, axis=0)), to=INT64.dtype)
 
     return result, offset2bag, bag_size, max_indices
 
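
The casts added above exist because PyTorch returns the three auxiliary outputs of `embedding_bag` (`offset2bag`, `bag_size`, `max_indices`) as int64 tensors, so the exported graph must produce INT64 values even though only their shapes matter here. A minimal PyTorch sketch, not part of this diff, illustrating the expected dtypes (assuming the op is reachable as `torch.ops.aten.embedding_bag`):

```python
# Illustrative only: the auxiliary outputs of aten::embedding_bag are int64,
# which is why the ONNX lowering casts its placeholder outputs to INT64.
import torch

weight = torch.randn(10, 3)
indices = torch.tensor([3, 1, 4, 5, 3])
offsets = torch.tensor([0, 2])
_, offset2bag, bag_size, max_indices = torch.ops.aten.embedding_bag(weight, indices, offsets)
assert offset2bag.dtype == bag_size.dtype == max_indices.dtype == torch.int64
```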
@@ -3285,27 +3287,40 @@ def aten_embedding_bag_padding_idx(
     sparse: bool = False,
     per_sample_weights: Optional[TFloat] = None,
     include_last_offset: bool = False,
-    padding_idx: int = -1,
+    padding_idx: Optional[int] = None,
 ) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
     """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)
 
     We add default values for the attributes to accommodate _embedding_bag as well:
     _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1)
     """
-    assert padding_idx is not None, (
-        "padding_idx must not be None. This is likely a dispatcher error"
-    )
 
     if per_sample_weights is None:
         per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices))
         per_sample_weights = op.CastLike(per_sample_weights, weight)
 
-    # Change padding_idx to positive value, -1 means the last index
-    if padding_idx < 0:
-        padding_idx = weight.shape[0] + padding_idx
+    if padding_idx is not None:
+        # Call the existing function for handling padding_idx
+        result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx(
+            weight,
+            indices,
+            offsets,
+            mode,
+            per_sample_weights,
+            include_last_offset,
+            padding_idx,
+        )
 
-    result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx(
-        weight, indices, offsets, mode, per_sample_weights, include_last_offset, padding_idx
+        return result, offset2bag, bag_size, max_indices
+
+    # When padding_idx is None, use the standard embedding_bag implementation
+    result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx(
+        weight,
+        indices,
+        offsets,
+        mode,
+        per_sample_weights,
+        include_last_offset,
     )
 
     return result, offset2bag, bag_size, max_indices
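
For reference, a small PyTorch sketch, not part of this diff, of the behavior the `padding_idx is not None` branch has to reproduce: indices equal to `padding_idx` do not contribute to their bag's reduction.

```python
# Illustrative only: entries equal to padding_idx are skipped during the reduction.
import torch
import torch.nn.functional as F

weight = torch.arange(12.0).reshape(6, 2)
indices = torch.tensor([0, 2, 1, 2])
offsets = torch.tensor([0, 2])
out = F.embedding_bag(indices, weight, offsets, mode="sum", padding_idx=2)
# Bag 0 keeps only weight[0]; bag 1 keeps only weight[1].
assert torch.equal(out, torch.stack([weight[0], weight[1]]))
```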
@@ -3322,6 +3337,12 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
     padding_idx: int,
 ) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
     neg_1 = op.Constant(value_ints=[-1])
+
+    num_embeddings = op.Shape(weight, start=0, end=1)  # Get number of rows in weight
+    num_embeddings_scalar = op.Squeeze(num_embeddings)
+    if padding_idx < 0:
+        padding_idx = padding_idx + num_embeddings_scalar
+
     # Get weight out according to indices,
     # e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
     indices_weight = op.Gather(weight, indices)
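
The added lines normalize a negative `padding_idx` against the number of embedding rows using `Shape`/`Squeeze` ops. A plain-Python sketch of the same rule (the helper name here is hypothetical, for illustration only):

```python
# Illustrative only: negative padding_idx follows Python's negative-index
# convention, so -1 refers to the last embedding row.
def normalize_padding_idx(padding_idx: int, num_embeddings: int) -> int:
    return padding_idx + num_embeddings if padding_idx < 0 else padding_idx

assert normalize_padding_idx(-1, 10) == 9
assert normalize_padding_idx(3, 10) == 3
```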
@@ -3357,7 +3378,10 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
         cond_2 = j < end_pos
         while cond_2:
             index = op.Gather(indices, j)
-            if not op.Equal(index, padding_idx):
+            normalized_index = index
+            if index < 0:
+                normalized_index = index + num_embeddings_scalar
+            if not op.Equal(normalized_index, padding_idx):
                 # Something like the 'append' operation
                 curr_offsets = op.Concat(curr_offsets, op.Reshape(j, neg_1), axis=0)
             j = j + 1
@@ -3386,23 +3410,24 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
     result = op.CastLike(result, weight)
 
     if mode == 0:  # sum
-        offset2bag = op.Expand(0, op.Shape(indices))
+        offset2bag = op.Cast(op.Expand(0, op.Shape(indices)), to=INT64.dtype)
         if op.Equal(include_last_offset, True):
-            bag_size = op.Expand(0, op.Shape(offsets))
+            bag_size = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
+            max_indices = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
         else:
-            bag_size = op.Expand(0, op.Shape(offsets) - 1)
-        max_indices = op.Expand(0, op.Shape(bag_size))
+            bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
+            max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
     elif mode == 1:  # mean
-        offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
-        bag_size = op.Expand(0, op.Shape(offsets) - 1)
-        max_indices = op.Expand(0, op.Shape(bag_size))
+        offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
+        bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
+        max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
     else:  # mode == 2, max
-        offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
-        bag_size = op.Expand(0, op.Shape(offsets) - 1)
+        offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
+        bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
         # shape = (bag_size.dim[0], weight.dim[1])
         dim_0 = op.Shape(bag_size, start=0, end=1)
         dim_1 = op.Shape(weight, start=1, end=2)
-        max_indices = op.Expand(0, op.Concat(dim_0, dim_1, axis=0))
+        max_indices = op.Cast(op.Expand(0, op.Concat(dim_0, dim_1, axis=0)), to=INT64.dtype)
 
     return result, offset2bag, bag_size, max_indices
 
@@ -4382,7 +4407,6 @@ def aten_grid_sampler(
     padding_mode_options = ("zeros", "border", "reflection")
     padding_mode_str = padding_mode_options[padding_mode]
 
-    # Only one onnx Op so don't put into private function
     return op.GridSample(
         input,
         grid,
@@ -4408,7 +4432,6 @@ def aten_grid_sampler_2d(
     padding_mode_options = ("zeros", "border", "reflection")
     padding_mode_str = padding_mode_options[padding_mode]
 
-    # Only one onnx Op so don't put into private function
     return op.GridSample(
         input,
         grid,
@@ -4698,7 +4721,7 @@ def _aten_index_onnx(
     if _has_none_in_middle(indices):
         # If there is None in the middle, Advanced Indexing cannot decide where to put
         # the new dimensions. So it places them in the front, like GatherND does.
-        return op.Identity(self)
+        return self
 
     # When the indices are consecutive, Advanced Indexing will place the new dimensions
     # (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
@@ -4744,7 +4767,9 @@ def _aten_index_onnx(
 
 
 @torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
-def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
+def aten_index(
+    self: TensorType, indices: Sequence[Optional[Union[INT64, BOOL]]]
+) -> TensorType:
     """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
 
     NOTE: Understanding `aten::index`
@@ -4764,17 +4789,19 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy
 
     None in `indices` are like fillers for dimensions that cannot be removed in the process.
     """
+    # Handle Boolean indexing first
+    if any(index is not None and index.dtype == ir.DataType.BOOL for index in indices):
+        return _aten_index_bool(self, indices)
 
     index_ranks = [len(index.shape) for index in indices if index is not None]
 
     return _aten_index_onnx(self, indices, index_ranks)
 
 
-@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
-def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType:  # pylint: disable=inconsistent-return-statements
+def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType:
     index_ranks = [len(index.shape) for index in indices if index is not None]
 
-    if index_ranks[0] == 1:
+    if all(rank == 1 for rank in index_ranks):
         # indices contains scalar only.
         new_indices = [
             op.Transpose(op.NonZero(index), perm=[1, 0]) if index is not None else None
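
A short PyTorch sketch, not part of this diff, of the equivalence the rank-1 branch relies on: a 1-D Boolean mask behaves like integer indexing with the mask's nonzero positions, which is what the `NonZero`/`Transpose`/`Squeeze` conversion above produces.

```python
# Illustrative only: a 1-D boolean mask is equivalent to indexing with the
# positions where the mask is True.
import torch

x = torch.arange(12).reshape(3, 4)
mask = torch.tensor([True, False, True])
assert torch.equal(x[mask], x[mask.nonzero().squeeze(1)])
```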
@@ -4784,6 +4811,7 @@ def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Tens
             op.Squeeze(index, axes=[1]) if index is not None else None for index in new_indices
         ]
         return _aten_index_onnx(self, new_indices, index_ranks)
+
     else:
         input_rank = len(self.shape)
         # Prepare perm for transposing self tensor.
@@ -4800,15 +4828,19 @@ def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Tens
             if index is None:
                 self = op.Transpose(self, perm=trans_perm)
                 count_of_none += 1
-            else:
-                new_indices = op.Transpose(op.NonZero(index), perm=[1, 0])
-                result = op.GatherND(self, new_indices, batch_dims=0)
-                finla_rank = input_rank - (len(index.shape) - 1)
-                trans_perm = list(range(finla_rank))
-                trans_perm = trans_perm[-1:] + trans_perm[:-1]
-                for _ in range(count_of_none):
-                    result = op.Transpose(result, perm=trans_perm)
-                return result
+                continue
+
+            new_indices = op.Transpose(op.NonZero(index), perm=[1, 0])
+            result = op.GatherND(self, new_indices, batch_dims=0)
+            final_rank = input_rank - (len(index.shape) - 1)
+            trans_perm = list(range(final_rank))
+            trans_perm = trans_perm[-1:] + trans_perm[:-1]
+            for _ in range(count_of_none):
+                result = op.Transpose(result, perm=trans_perm)
+            # FIXME(justinchuby): Even though this logic passes the tests, it still looks strange:
+            # why does it return early here instead of continuing to process the remaining indices?
+            # I think the assumption here is that there can be only one Boolean index in the indices list?
+            return result
 
 
 def aten_index_add(
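
For context, a sketch, not part of this diff, of the case the transpose-and-GatherND path above handles: a Boolean mask applied to a non-leading dimension, with leading `None` entries.

```python
# Illustrative only: a boolean mask on a trailing dimension selects along that
# dimension; the ONNX lowering emulates this by transposing the masked dimension
# to the front, gathering, and transposing back once per leading None.
import torch

x = torch.arange(24).reshape(2, 3, 4)
mask = torch.tensor([True, False, True, True])  # mask over the last dimension
assert torch.equal(x[:, :, mask], x.index_select(-1, mask.nonzero().squeeze(1)))
```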
@@ -4830,7 +4862,7 @@ def aten_index_copy(
 @torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True)
 def aten_index_put(
     self: TReal,
-    indices: Sequence[INT64],
+    indices: Sequence[Optional[Union[INT64, BOOL]]],
     values: TReal,
     accumulate: bool = False,
 ) -> TReal:
@@ -4839,6 +4871,9 @@ def aten_index_put(
     See implementation of `torch.onnx.symbolic_opset11.index_put
     <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
     """
+    if any(index is not None and index.dtype == BOOL.dtype for index in indices):
+        return _aten_index_put_bool(self, indices, values, accumulate)
+
     # Ensure the number of indices matches the tensor rank by appending trailing Nones.
     self_rank = len(self.shape)
     if len(indices) < self_rank:
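
A minimal PyTorch sketch, not part of this diff, of the semantics the new Boolean branch dispatches to: `index_put` with a Boolean mask is masked assignment (or masked accumulation when `accumulate=True`).

```python
# Illustrative only: index_put with a boolean mask writes `values` at the
# positions where the mask is True.
import torch

x = torch.zeros(2, 3)
mask = torch.tensor([[True, False, True], [False, True, False]])
values = torch.tensor(1.0)
expected = x.clone()
expected[mask] = values
assert torch.equal(x.index_put((mask,), values), expected)
```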
@@ -4971,8 +5006,7 @@ def same_shape(other_shape: ir.Shape) -> bool:
     return result
 
 
-@torch_op("aten::index_put", trace_only=True)
-def aten_index_put_bool(
+def _aten_index_put_bool(
     self: TReal,
     indices: Sequence[BOOL],
     values: TReal,