@@ -8119,8 +8119,8 @@ def aten_std_mean_correction(
81198119
81208120
81218121@torch_op ("aten::stft" , private = True )
8122- def _add_batch_dimension (self : TFloatOrBFloat16 ) -> Tuple [TFloatOrBFloat16 , INT64 ]:
8123- signal_rank = Rank ( self )
8122+ def _add_batch_dimension (self : TFloat ) -> Tuple [TFloat , INT64 ]:
8123+ signal_rank = op . Size ( op . Shape ( self ) )
81248124 if signal_rank == 1 :
81258125 # Add a batch dimension
81268126 self = op .Unsqueeze (self , op .Constant (value_ints = [0 ]))
@@ -8129,8 +8129,8 @@ def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT6
81298129
81308130@torch_op ("aten::stft" , private = True )
81318131def _center_window_around_zeros_if_needed (
8132- window : TFloatOrBFloat16 , n_fft : int
8133- ) -> TFloatOrBFloat16 :
8132+ window : TFloat , n_fft : int
8133+ ) -> TFloat :
81348134 # first dimension
81358135 n_win = op .Shape (window , start = 0 , end = 1 )
81368136 # Center window around zeros if needed (required by ONNX's STFT)
@@ -8150,7 +8150,7 @@ def _center_window_around_zeros_if_needed(
81508150
81518151
81528152@torch_op ("aten::stft" , private = True )
8153- def _create_window_from_win_length (win_length : int , n_fft : int ) -> TFloatOrBFloat16 :
8153+ def _create_window_from_win_length (win_length : int , n_fft : int ) -> TFloat :
81548154 left = (n_fft - win_length ) / 2
81558155
81568156 right = n_fft - left - win_length
@@ -8165,16 +8165,16 @@ def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloa
81658165
81668166
81678167@torch_op ("aten::stft" , private = True )
8168- def _create_window_from_n_fft (n_fft : int ) -> TFloatOrBFloat16 :
8168+ def _create_window_from_n_fft (n_fft : int ) -> TFloat :
81698169 n_fft_tensor = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
81708170 window = op .Expand (op .Constant (value_ints = [1 ]), n_fft_tensor )
81718171 return window
81728172
81738173
81748174@torch_op ("aten::stft" , private = True )
81758175def _normalize_fft_result (
8176- signal : TFloatOrBFloat16 , result : TFloatOrBFloat16 , n_fft : int
8177- ) -> TFloatOrBFloat16 :
8176+ signal : TFloat , result : TFloat , n_fft : int
8177+ ) -> TFloat :
81788178 n_fft_tensor = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
81798179 sqrt_nfft = op .Sqrt (op .CastLike (n_fft_tensor , signal ))
81808180 result = result / sqrt_nfft
@@ -8183,13 +8183,13 @@ def _normalize_fft_result(
81838183
81848184@torch_op ("aten::stft" , private = True )
81858185def _aten_stft_onnx (
8186- signal : TFloatOrBFloat16 ,
8186+ signal : TFloat ,
81878187 frame_step_const : INT64 ,
8188- window : Union [TFloatOrBFloat16 , INT64 ],
8188+ window : Union [TFloat , INT64 ],
81898189 frame_length_const : INT64 ,
81908190 signal_rank : INT64 ,
81918191 onesided : int ,
8192- ) -> TFloatOrBFloat16 :
8192+ ) -> TFloat :
81938193 window = op .CastLike (window , signal )
81948194 result = op .STFT (signal , frame_step_const , window , frame_length_const , onesided = onesided )
81958195 result = op .Transpose (result , perm = [0 , 2 , 1 , 3 ])
@@ -8201,15 +8201,15 @@ def _aten_stft_onnx(
82018201
82028202@torch_op ("aten::stft" , trace_only = True )
82038203def aten_stft (
8204- self : TFloatOrBFloat16 ,
8204+ self : TFloat ,
82058205 n_fft : int ,
82068206 hop_length : Optional [int ] = None ,
82078207 win_length : Optional [int ] = None ,
8208- window : Optional [TFloatOrBFloat16 ] = None ,
8208+ window : Optional [TFloat ] = None ,
82098209 normalized : bool = False ,
82108210 onesided : Optional [bool ] = None ,
82118211 return_complex : Optional [bool ] = None ,
8212- ) -> TFloatOrBFloat16 :
8212+ ) -> TFloat :
82138213 """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""
82148214
82158215 # NOTE: regarless of the value of return_complex, we always return a real representation.
0 commit comments