@@ -7223,11 +7223,106 @@ def aten_rnn_tanh_cell(
72237223 raise NotImplementedError ()
72247224
72257225
7226- # roll is decomposed by PyTorch
7226+ @ torch_op ( "aten:: roll" , trace_only = True )
72277227def aten_roll (self : TTensor , shifts : Sequence [int ], dims : Sequence [int ] = ()) -> TTensor :
72287228 """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor"""
72297229
7230- raise NotImplementedError ()
7230+ if isinstance (shifts , int ):
7231+ shifts = [shifts ]
7232+
7233+ if isinstance (dims , int ):
7234+ dims = [dims ]
7235+
7236+ self_rank = len (self .shape )
7237+ if self_rank == 0 :
7238+ return op .Identity (self )
7239+ elif self .shape [0 ] == 0 : # empty tensor
7240+ return op .Identity (self )
7241+
7242+ # NOTE: In pytorch, default value of dims is an empty list.
7243+ if len (dims ) == 0 : # Empty sequence
7244+ assert len (shifts ) == 1 , "shifts should be a single integer if dims is empty"
7245+ return _aten_roll_shift_no_dim_onnx (self , shifts [0 ])
7246+ else :
7247+ assert len (shifts ) == len (dims )
7248+ result = self
7249+ for i , shift in enumerate (shifts ):
7250+ dim = dims [i ]
7251+ result = _aten_roll_shift_and_dim_onnx (result , shift , dim )
7252+ return result
7253+
7254+
7255+ @torch_op ("aten::roll" , trace_only = True , complex = True )
7256+ def aten_roll_complex (
7257+ self : TTensor , shifts : Sequence [int ], dims : Sequence [int ] = ()
7258+ ) -> TTensor :
7259+ """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor"""
7260+
7261+ if isinstance (shifts , int ):
7262+ shifts = [shifts ]
7263+
7264+ if isinstance (dims , int ):
7265+ dims = [dims ]
7266+
7267+ self_rank = len (self .shape )
7268+ if self_rank == 1 :
7269+ return op .Identity (self )
7270+
7271+ if self .shape [0 ] == 0 : # empty tensor
7272+ return op .Identity (self )
7273+
7274+ self_real = op .Slice (self , [0 ], [1 ], axes = [- 1 ])
7275+ self_imag = op .Slice (self , [1 ], [2 ], axes = [- 1 ])
7276+ if not dims :
7277+ assert len (shifts ) == 1 , "shifts should be a single integer if dims is empty"
7278+ shift_real = _aten_roll_shift_no_dim_onnx (self_real , shifts [0 ])
7279+ shift_imag = _aten_roll_shift_no_dim_onnx (self_imag , shifts [0 ])
7280+
7281+ result = op .Concat (shift_real , shift_imag , axis = - 1 )
7282+
7283+ else :
7284+ assert len (shifts ) == len (dims )
7285+ for i , dim in enumerate (dims ):
7286+ self_real = _aten_roll_shift_and_dim_onnx (self_real , shifts [i ], dim )
7287+ self_imag = _aten_roll_shift_and_dim_onnx (self_imag , shifts [i ], dim )
7288+
7289+ result = op .Concat (self_real , self_imag , axis = - 1 )
7290+ return result
7291+
7292+
7293+ def _aten_roll_shift_no_dim_onnx (self : TTensor , shift : int ) -> TTensor :
7294+ neg_1 = op .Constant (value_ints = [- 1 ])
7295+ # flatten the self tensor: from [[A,B],[C,D]] to [A,B,C,D]
7296+ self_flatten = op .Reshape (self , neg_1 )
7297+ # Compute slice length
7298+ if shift < 0 :
7299+ # For [A,B,C,D], if shift is -1, slice_length = -(-1) = 1, means move [A] to the end
7300+ slice_length = op .Constant (value_ints = [- shift ])
7301+ else :
7302+ # For [A,B,C,D], if shift is 1, slice_length = 4 - 1 = 3, means move [A,B,C] to the end
7303+ # The effect equals to move [D] to the beginning
7304+ slice_length = op .Size (self_flatten ) - op .Constant (value_ints = [shift ])
7305+ # Get second part of the tensor, e.g. [A,B,C]
7306+ suffix = op .Slice (self_flatten , op .Constant (value_ints = [0 ]), slice_length )
7307+ # Get first part of the tensor, e.g. [D]
7308+ prefix = op .Slice (self_flatten , slice_length , op .Reshape (op .Size (self_flatten ), neg_1 ))
7309+ # Concat first+second together, e.g. [D,A,B,C]
7310+ result = op .Concat (prefix , suffix , axis = 0 )
7311+ return op .Reshape (result , op .Shape (self ))
7312+
7313+
7314+ def _aten_roll_shift_and_dim_onnx (self : TTensor , shift : int , dim : int ) -> TTensor :
7315+ neg_1 = op .Constant (value_ints = [- 1 ])
7316+ dim_tensor = op .Constant (value_ints = [dim ])
7317+ if shift < 0 :
7318+ slice_length = op .Constant (value_ints = [- shift ])
7319+ else :
7320+ slice_length = op .Shape (self , start = dim , end = dim + 1 ) - op .Constant (value_ints = [shift ])
7321+ # from [A,B,C,D] -> [D,A,B,C], [D] is prefix, [A,B,C] is suffix
7322+ suffix = op .Slice (self , op .Constant (value_ints = [0 ]), slice_length , axes = dim_tensor )
7323+ prefix = op .Slice (self , slice_length , op .Reshape (op .Size (self ), neg_1 ), axes = dim_tensor )
7324+ result = op .Concat (prefix , suffix , axis = dim )
7325+ return result
72317326
72327327
72337328def aten_rot90 (self : TensorType , k : int = 1 , dims : Sequence [int ] = (0 , 1 )) -> TensorType :
0 commit comments