@@ -3151,6 +3151,7 @@ def aten_embedding_bag(
31513151 sparse : bool = False ,
31523152 per_sample_weights : Optional [TFloat ] = None ,
31533153 include_last_offset : bool = False ,
3154+ padding_idx : Optional [int ] = None ,
31543155) -> Tuple [TFloat , TFloat , TFloat , TFloat ]:
31553156 """embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)"""
31563157
@@ -3247,23 +3248,24 @@ def _aten_embedding_bag_onnx(
32473248
32483249 # Only compute the shape of other 3 outputs, we don't care the value
32493250 if mode == 0 : # sum
3250- offset2bag = op .Shape (indices , start = 0 , end = 0 ) # Generate empty tensor
3251+ offset2bag = op .Cast ( op . Shape (indices , start = 0 , end = 0 ), to = INT64 . dtype )
32513252 if op .Equal (include_last_offset , True ):
3252- bag_size = op .Expand (0 , op .Shape (offsets ))
3253+ bag_size = op .Cast (op .Expand (0 , op .Shape (offsets )), to = INT64 .dtype )
3254+ max_indices = op .Cast (op .Expand (0 , op .Shape (offsets )), to = INT64 .dtype )
32533255 else :
3254- bag_size = op .Expand (0 , op .Shape (offsets ) - 1 )
3255- max_indices = op .Expand (0 , op .Shape (bag_size ) )
3256+ bag_size = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
3257+ max_indices = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
32563258 elif mode == 1 : # mean
3257- offset2bag = op .Expand (0 , op .Shape (indices , start = 0 , end = 1 ))
3258- bag_size = op .Expand (0 , op .Shape (offsets ) - 1 )
3259- max_indices = op .Expand (0 , op .Shape (bag_size ) )
3259+ offset2bag = op .Cast ( op . Expand (0 , op .Shape (indices , start = 0 , end = 1 )), to = INT64 . dtype )
3260+ bag_size = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
3261+ max_indices = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
32603262 else : # max
3261- offset2bag = op .Expand (0 , op .Shape (indices , start = 0 , end = 1 ))
3262- bag_size = op .Expand (0 , op .Shape (offsets ) - 1 )
3263+ offset2bag = op .Cast ( op . Expand (0 , op .Shape (indices , start = 0 , end = 1 )), to = INT64 . dtype )
3264+ bag_size = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
32633265 # shape = (bag_size.dim[0], weight.dim[1])
32643266 dim_0 = op .Shape (bag_size , start = 0 , end = 1 )
32653267 dim_1 = op .Shape (weight , start = 1 , end = 2 )
3266- max_indices = op .Expand (0 , op .Concat (dim_0 , dim_1 , axis = 0 ))
3268+ max_indices = op .Cast ( op . Expand (0 , op .Concat (dim_0 , dim_1 , axis = 0 )), to = INT64 . dtype )
32673269
32683270 return result , offset2bag , bag_size , max_indices
32693271
@@ -3285,27 +3287,40 @@ def aten_embedding_bag_padding_idx(
32853287 sparse : bool = False ,
32863288 per_sample_weights : Optional [TFloat ] = None ,
32873289 include_last_offset : bool = False ,
3288- padding_idx : int = - 1 ,
3290+ padding_idx : Optional [ int ] = None ,
32893291) -> Tuple [TFloat , TFloat , TFloat , TFloat ]:
32903292 """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)
32913293
32923294 We add default values for the attributes to accommodate _embedding_bag as well:
32933295 _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1)
32943296 """
3295- assert padding_idx is not None , (
3296- "padding_idx must not be None. This is likely a dispatcher error"
3297- )
32983297
32993298 if per_sample_weights is None :
33003299 per_sample_weights = op .Expand (op .Constant (value_floats = [1.0 ]), op .Shape (indices ))
33013300 per_sample_weights = op .CastLike (per_sample_weights , weight )
33023301
3303- # Change padding_idx to positive value, -1 means the last index
3304- if padding_idx < 0 :
3305- padding_idx = weight .shape [0 ] + padding_idx
3302+ if padding_idx is not None :
3303+ # Call the existing function for handling padding_idx
3304+ result , offset2bag , bag_size , max_indices = _aten_embedding_bag_1d_padding_idx_onnx (
3305+ weight ,
3306+ indices ,
3307+ offsets ,
3308+ mode ,
3309+ per_sample_weights ,
3310+ include_last_offset ,
3311+ padding_idx ,
3312+ )
3313+
3314+ return result , offset2bag , bag_size , max_indices
33063315
3307- result , offset2bag , bag_size , max_indices = _aten_embedding_bag_1d_padding_idx_onnx (
3308- weight , indices , offsets , mode , per_sample_weights , include_last_offset , padding_idx
3316+ # When padding_idx is None, use the standard embedding_bag implementation
3317+ result , offset2bag , bag_size , max_indices = _aten_embedding_bag_onnx (
3318+ weight ,
3319+ indices ,
3320+ offsets ,
3321+ mode ,
3322+ per_sample_weights ,
3323+ include_last_offset ,
33093324 )
33103325
33113326 return result , offset2bag , bag_size , max_indices
@@ -3322,6 +3337,12 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
33223337 padding_idx : int ,
33233338) -> Tuple [TFloat , TFloat , TFloat , TFloat ]:
33243339 neg_1 = op .Constant (value_ints = [- 1 ])
3340+
3341+ num_embeddings = op .Shape (weight , start = 0 , end = 1 ) # Get number of rows in weight
3342+ num_embeddings_scalar = op .Squeeze (num_embeddings )
3343+ if padding_idx < 0 :
3344+ padding_idx = padding_idx + num_embeddings_scalar
3345+
33253346 # Get weight out according to indices,
33263347 # e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
33273348 indices_weight = op .Gather (weight , indices )
@@ -3357,7 +3378,10 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
33573378 cond_2 = j < end_pos
33583379 while cond_2 :
33593380 index = op .Gather (indices , j )
3360- if not op .Equal (index , padding_idx ):
3381+ normalized_index = index
3382+ if index < 0 :
3383+ normalized_index = index + num_embeddings_scalar
3384+ if not op .Equal (normalized_index , padding_idx ):
33613385 # Something like the 'append' operation
33623386 curr_offsets = op .Concat (curr_offsets , op .Reshape (j , neg_1 ), axis = 0 )
33633387 j = j + 1
@@ -3386,23 +3410,24 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
33863410 result = op .CastLike (result , weight )
33873411
33883412 if mode == 0 : # sum
3389- offset2bag = op .Expand (0 , op .Shape (indices ))
3413+ offset2bag = op .Cast ( op . Expand (0 , op .Shape (indices )), to = INT64 . dtype )
33903414 if op .Equal (include_last_offset , True ):
3391- bag_size = op .Expand (0 , op .Shape (offsets ))
3415+ bag_size = op .Cast (op .Expand (0 , op .Shape (offsets )), to = INT64 .dtype )
3416+ max_indices = op .Cast (op .Expand (0 , op .Shape (offsets )), to = INT64 .dtype )
33923417 else :
3393- bag_size = op .Expand (0 , op .Shape (offsets ) - 1 )
3394- max_indices = op .Expand (0 , op .Shape (bag_size ) )
3418+ bag_size = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
3419+ max_indices = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
33953420 elif mode == 1 : # mean
3396- offset2bag = op .Expand (0 , op .Shape (indices , start = 0 , end = 1 ))
3397- bag_size = op .Expand (0 , op .Shape (offsets ) - 1 )
3398- max_indices = op .Expand (0 , op .Shape (bag_size ) )
3421+ offset2bag = op .Cast ( op . Expand (0 , op .Shape (indices , start = 0 , end = 1 )), to = INT64 . dtype )
3422+ bag_size = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
3423+ max_indices = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
33993424 else : # mode == 2, max
3400- offset2bag = op .Expand (0 , op .Shape (indices , start = 0 , end = 1 ))
3401- bag_size = op .Expand (0 , op .Shape (offsets ) - 1 )
3425+ offset2bag = op .Cast ( op . Expand (0 , op .Shape (indices , start = 0 , end = 1 )), to = INT64 . dtype )
3426+ bag_size = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
34023427 # shape = (bag_size.dim[0], weight.dim[1])
34033428 dim_0 = op .Shape (bag_size , start = 0 , end = 1 )
34043429 dim_1 = op .Shape (weight , start = 1 , end = 2 )
3405- max_indices = op .Expand (0 , op .Concat (dim_0 , dim_1 , axis = 0 ))
3430+ max_indices = op .Cast ( op . Expand (0 , op .Concat (dim_0 , dim_1 , axis = 0 )), to = INT64 . dtype )
34063431
34073432 return result , offset2bag , bag_size , max_indices
34083433
0 commit comments