Skip to content

Commit 449e1fe

Browse files
committed
Fix aten_stft
1 parent 085401d commit 449e1fe

File tree

1 file changed

+14
-14
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+14
-14
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8119,8 +8119,8 @@ def aten_std_mean_correction(
81198119

81208120

81218121
@torch_op("aten::stft", private=True)
8122-
def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]:
8123-
signal_rank = Rank(self)
8122+
def _add_batch_dimension(self: TFloat) -> Tuple[TFloat, INT64]:
8123+
signal_rank = op.Size(op.Shape(self))
81248124
if signal_rank == 1:
81258125
# Add a batch dimension
81268126
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
@@ -8129,8 +8129,8 @@ def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT6
81298129

81308130
@torch_op("aten::stft", private=True)
81318131
def _center_window_around_zeros_if_needed(
8132-
window: TFloatOrBFloat16, n_fft: int
8133-
) -> TFloatOrBFloat16:
8132+
window: TFloat, n_fft: int
8133+
) -> TFloat:
81348134
# first dimension
81358135
n_win = op.Shape(window, start=0, end=1)
81368136
# Center window around zeros if needed (required by ONNX's STFT)
@@ -8150,7 +8150,7 @@ def _center_window_around_zeros_if_needed(
81508150

81518151

81528152
@torch_op("aten::stft", private=True)
8153-
def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16:
8153+
def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloat:
81548154
left = (n_fft - win_length) / 2
81558155

81568156
right = n_fft - left - win_length
@@ -8165,16 +8165,16 @@ def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloa
81658165

81668166

81678167
@torch_op("aten::stft", private=True)
8168-
def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16:
8168+
def _create_window_from_n_fft(n_fft: int) -> TFloat:
81698169
n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
81708170
window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor)
81718171
return window
81728172

81738173

81748174
@torch_op("aten::stft", private=True)
81758175
def _normalize_fft_result(
8176-
signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int
8177-
) -> TFloatOrBFloat16:
8176+
signal: TFloat, result: TFloat, n_fft: int
8177+
) -> TFloat:
81788178
n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
81798179
sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal))
81808180
result = result / sqrt_nfft
@@ -8183,13 +8183,13 @@ def _normalize_fft_result(
81838183

81848184
@torch_op("aten::stft", private=True)
81858185
def _aten_stft_onnx(
8186-
signal: TFloatOrBFloat16,
8186+
signal: TFloat,
81878187
frame_step_const: INT64,
8188-
window: Union[TFloatOrBFloat16, INT64],
8188+
window: Union[TFloat, INT64],
81898189
frame_length_const: INT64,
81908190
signal_rank: INT64,
81918191
onesided: int,
8192-
) -> TFloatOrBFloat16:
8192+
) -> TFloat:
81938193
window = op.CastLike(window, signal)
81948194
result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided)
81958195
result = op.Transpose(result, perm=[0, 2, 1, 3])
@@ -8201,15 +8201,15 @@ def _aten_stft_onnx(
82018201

82028202
@torch_op("aten::stft", trace_only=True)
82038203
def aten_stft(
8204-
self: TFloatOrBFloat16,
8204+
self: TFloat,
82058205
n_fft: int,
82068206
hop_length: Optional[int] = None,
82078207
win_length: Optional[int] = None,
8208-
window: Optional[TFloatOrBFloat16] = None,
8208+
window: Optional[TFloat] = None,
82098209
normalized: bool = False,
82108210
onesided: Optional[bool] = None,
82118211
return_complex: Optional[bool] = None,
8212-
) -> TFloatOrBFloat16:
8212+
) -> TFloat:
82138213
"""stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""
82148214

82158215
# NOTE: regarless of the value of return_complex, we always return a real representation.

0 commit comments

Comments
 (0)