@@ -4468,15 +4468,21 @@ def _aten_index_put_dynamic(
44684468 values : TReal ,
44694469 accumulate : bool = False ,
44704470) -> TReal :
4471+ def _1dint (i : int ):
4472+ return op .Constant (value_ints = ir .AttrInt64s ("value_ints" , [i ]))
4473+
4474+ def _0dint (i : int ):
4475+ return op .Constant (value_int = ir .AttrInt64 ("value_int" , i ))
4476+
44714477 def _make_range_or_cast (ind , shape_x , static_shape : bool , dim : int ):
44724478 if ind is not None :
44734479 return op .Cast (ind , to = INT64 .dtype ), False
44744480 return (
44754481 op .Cast (
44764482 op .Range ( # Range does not return a typed result
4477- 0 ,
4483+ _0dint ( 0 ) ,
44784484 op .Squeeze (op .Shape (x , start = dim , end = dim + 1 )),
4479- 1 ,
4485+ _0dint ( 1 ) ,
44804486 ),
44814487 to = INT64 .dtype ,
44824488 ),
@@ -4500,21 +4506,21 @@ def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int):
45004506 if expanded :
45014507 exped .append ((i , ind ))
45024508 expand_value_shape .append (op .Shape (x , start = i , end = i + 1 ))
4503- reshape_value_shape2 .append ([ 1 ] )
4509+ reshape_value_shape2 .append (_1dint ( 1 ) )
45044510 else :
4505- expand_value_shape .append ([ 1 ] )
4511+ expand_value_shape .append (_1dint ( 1 ) )
45064512 reshape_value_shape2 .append (op .Shape (ind ))
45074513 fixed .append ((i , ind ))
45084514
4509- reshape_value_shape1 = [1 ] * len (indices )
4515+ reshape_value_shape1 = [_1dint ( 1 ) ] * len (indices )
45104516 if len (fixed ) <= 1 :
45114517 reshape_value_shape1 = None
45124518 elif fixed :
4513- reshape_value_shape1 [fixed [- 1 ][0 ]] = - 1
4519+ reshape_value_shape1 [fixed [- 1 ][0 ]] = _1dint ( - 1 )
45144520
45154521 def _mkstride (x , i ):
45164522 if i >= len (x .shape ) - 1 :
4517- return [ 1 ]
4523+ return _1dint ( 1 )
45184524 if i == len (x .shape ) - 2 :
45194525 return op .Shape (x , start = i + 1 )
45204526 return op .ReduceProd (op .Shape (x , start = i + 1 ), keepdims = 1 )
@@ -4547,9 +4553,9 @@ def _mkstride(x, i):
45474553 # Bug here: Error calling operator 'Concat' with args
45484554 # (SymbolicTensor(name='anonymous:124529632436112', producer=anonymous_node:124529631522416, index=0), [1], [1])
45494555 expanded_values = op .Expand (expanded_values , op .Concat (* expand_value_shape , axis = 0 ))
4550- flat_ind = op .Reshape (unflat , [ - 1 ] )
4551- expanded_values = op .Reshape (expanded_values , [ - 1 ] )
4552- flat_x = op .Reshape (x , [ - 1 ] )
4556+ flat_ind = op .Reshape (unflat , _1dint ( - 1 ) )
4557+ expanded_values = op .Reshape (expanded_values , _1dint ( - 1 ) )
4558+ flat_x = op .Reshape (x , _1dint ( - 1 ) )
45534559 scat_kwargs = {"reduction" : "add" } if accumulate else {}
45544560 flat_up_x = op .ScatterElements (flat_x , flat_ind , expanded_values , ** scat_kwargs )
45554561 return op .Reshape (flat_up_x , op .Shape (x ))
0 commit comments