Commit bd5dde4

ydwu4 authored and amathewc committed
[export] fix stft decomp and making it consistent with cpp impl. (pytorch#149232)
Summary: We change the fake impl of stft to follow its cpp implementation more closely ([here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L951-L963)), where `n_frames = 1 + (len - n_fft) / hop_length;` is also an integer division.

Test Plan: Existing tests and buck2 build --flagfile fbcode//mode/dev fbcode//executorch/examples/models/fb/llama4:speech_transform.pte

Differential Revision: D71209142

Edit: we kept the original path unchanged.

Pull Request resolved: pytorch#149232
Approved by: https://github.com/jackzhxng
1 parent 689dc60 commit bd5dde4
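
The frame-count formula quoted in the summary can be checked with plain integer arithmetic. The sketch below is illustrative only: `n_frames` and `unfold_1d` are hypothetical helper names (neither is a PyTorch function), the signal is a plain Python list rather than a tensor, and `length >= n_fft` is assumed.

```python
def n_frames(length: int, n_fft: int, hop_length: int) -> int:
    """Frame count per the formula in the commit summary, with floor division."""
    if length < n_fft:
        raise ValueError("signal must be at least n_fft samples long")
    return 1 + (length - n_fft) // hop_length


def unfold_1d(signal, size, step):
    """Pure-Python stand-in for unfolding a 1-D sequence into sliding windows."""
    return [signal[i:i + size] for i in range(0, len(signal) - size + 1, step)]


# The closed-form count should agree with the number of windows actually produced,
# including when (length - n_fft) is not a multiple of hop_length.
for length in (10, 11, 12, 13):
    frames = unfold_1d(list(range(length)), size=4, step=3)
    assert len(frames) == n_frames(length, n_fft=4, hop_length=3)
```

Because the division truncates, trailing samples that do not fill a whole frame are dropped rather than producing a partial frame.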

1 file changed, +3 -1 lines changed

torch/_refs/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -3451,10 +3451,12 @@ def stft(
         left = (n_fft - win_length_) // 2
         window = aten.constant_pad_nd(window, [left, n_fft - win_length_ - left])
 
-    input = input.unfold(dimension=-1, size=n_fft, step=hop_length_)
     if not center and align_to_window:
         input_pad_amount = (n_fft - win_length_) // 2
         input = aten.pad(input, [input_pad_amount, input_pad_amount], pad_mode)
+
+    input = input.unfold(dimension=-1, size=n_fft, step=hop_length_)
+
     if window is not None:
         input = input * window
 
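
One way to read the reordering in this diff: padding before unfolding lengthens the signal, so more frames are produced and each keeps width `n_fft`, whereas padding after unfolding widens every frame instead. The sketch below illustrates that difference under stated assumptions. It uses plain Python lists, hypothetical helpers `unfold_1d` and `pad_last`, and zero padding in place of the `pad_mode` argument used by the real code. It is not the PyTorch implementation.

```python
def unfold_1d(signal, size, step):
    """Pure-Python stand-in for unfolding a 1-D sequence into sliding windows."""
    return [signal[i:i + size] for i in range(0, len(signal) - size + 1, step)]


def pad_last(seq, amount):
    """Zero-pad both ends of a sequence (an assumption; the real code passes pad_mode)."""
    return [0] * amount + list(seq) + [0] * amount


n_fft, win_length, hop_length = 8, 4, 2
signal = list(range(1, 13))  # 12 samples
pad_amount = (n_fft - win_length) // 2  # same expression as input_pad_amount in the diff

# Order before the change: unfold first, then pad the last dimension of each frame.
old_frames = [pad_last(f, pad_amount) for f in unfold_1d(signal, n_fft, hop_length)]

# Order after the change: pad the signal first, then unfold.
new_frames = unfold_1d(pad_last(signal, pad_amount), n_fft, hop_length)

print(len(old_frames), len(old_frames[0]))  # frame count and frame width, old order
print(len(new_frames), len(new_frames[0]))  # frame count and frame width, new order
```

In this toy setup the old order yields frames wider than `n_fft`, which would no longer line up with a window of length `n_fft`; the new order keeps the frame width at `n_fft` and the frame count follows `1 + (len - n_fft) // hop_length` applied to the padded length.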
