Skip to content

Commit 8d166ca

Browse files
authored
Optimize FFT (#766)
* Optimize FFT for real transforms
* Throw error if power is not specified (see huggingface/transformers#27772)
1 parent 8963720 commit 8d166ca

File tree

2 files changed

+60
-42
lines changed

2 files changed

+60
-42
lines changed

src/utils/audio.js

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -473,6 +473,13 @@ export function spectrogram(
473473
throw new Error("hop_length must be greater than zero");
474474
}
475475

476+
if (power === null && mel_filters !== null) {
477+
throw new Error(
478+
"You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrogram. " +
479+
"Specify `power` to fix this issue."
480+
);
481+
}
482+
476483
if (center) {
477484
if (pad_mode !== 'reflect') {
478485
throw new Error(`pad_mode="${pad_mode}" not implemented yet.`)
@@ -547,8 +554,6 @@ export function spectrogram(
547554
magnitudes[i] = row;
548555
}
549556

550-
// TODO what should happen if power is None?
551-
// https://github.com/huggingface/transformers/issues/27772
552557
if (power !== null && power !== 2) {
553558
// slight optimization to not sqrt
554559
const pow = 2 / power; // we use 2 since we already squared

src/utils/maths.js

Lines changed: 53 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -364,20 +364,6 @@ class P2FFT {
364364
return res;
365365
}
366366

367-
/**
368-
* Completes the spectrum by adding its mirrored negative frequency components.
369-
* @param {Float64Array} spectrum The input spectrum.
370-
* @returns {void}
371-
*/
372-
completeSpectrum(spectrum) {
373-
const size = this._csize;
374-
const half = size >>> 1;
375-
for (let i = 2; i < half; i += 2) {
376-
spectrum[size - i] = spectrum[i];
377-
spectrum[size - i + 1] = -spectrum[i + 1];
378-
}
379-
}
380-
381367
/**
382368
* Performs a Fast Fourier Transform (FFT) on the given input data and stores the result in the output buffer.
383369
*
@@ -466,6 +452,7 @@ class P2FFT {
466452
}
467453

468454
// Loop through steps in decreasing order
455+
const table = this.table;
469456
for (step >>= 2; step >= 2; step >>= 2) {
470457
len = (size / step) << 1;
471458
const quarterLen = len >>> 2;
@@ -490,18 +477,18 @@ class P2FFT {
490477
const Dr = out[D];
491478
const Di = out[D + 1];
492479

493-
const tableBr = this.table[k];
494-
const tableBi = inv * this.table[k + 1];
480+
const tableBr = table[k];
481+
const tableBi = inv * table[k + 1];
495482
const MBr = Br * tableBr - Bi * tableBi;
496483
const MBi = Br * tableBi + Bi * tableBr;
497484

498-
const tableCr = this.table[2 * k];
499-
const tableCi = inv * this.table[2 * k + 1];
485+
const tableCr = table[2 * k];
486+
const tableCi = inv * table[2 * k + 1];
500487
const MCr = Cr * tableCr - Ci * tableCi;
501488
const MCi = Cr * tableCi + Ci * tableCr;
502489

503-
const tableDr = this.table[3 * k];
504-
const tableDi = inv * this.table[3 * k + 1];
490+
const tableDr = table[3 * k];
491+
const tableDi = inv * table[3 * k + 1];
505492
const MDr = Dr * tableDr - Di * tableDi;
506493
const MDi = Dr * tableDi + Di * tableDr;
507494

@@ -634,18 +621,18 @@ class P2FFT {
634621
}
635622
}
636623

637-
// TODO: Optimize once https://github.com/indutny/fft.js/issues/25 is fixed
638624
// Loop through steps in decreasing order
625+
const table = this.table;
639626
for (step >>= 2; step >= 2; step >>= 2) {
640627
len = (size / step) << 1;
641-
const quarterLen = len >>> 2;
628+
const halfLen = len >>> 1;
629+
const quarterLen = halfLen >>> 1;
630+
const hquarterLen = quarterLen >>> 1;
642631

643632
// Loop through offsets in the data
644633
for (outOff = 0; outOff < size; outOff += len) {
645-
// Full case
646-
const limit = outOff + quarterLen - 1;
647-
for (let i = outOff, k = 0; i < limit; i += 2, k += step) {
648-
const A = i;
634+
for (let i = 0, k = 0; i <= hquarterLen; i += 2, k += step) {
635+
const A = outOff + i;
649636
const B = A + quarterLen;
650637
const C = B + quarterLen;
651638
const D = C + quarterLen;
@@ -660,26 +647,30 @@ class P2FFT {
660647
const Dr = out[D];
661648
const Di = out[D + 1];
662649

663-
const tableBr = this.table[k];
664-
const tableBi = inv * this.table[k + 1];
650+
// Middle values
651+
const MAr = Ar;
652+
const MAi = Ai;
653+
654+
const tableBr = table[k];
655+
const tableBi = inv * table[k + 1];
665656
const MBr = Br * tableBr - Bi * tableBi;
666657
const MBi = Br * tableBi + Bi * tableBr;
667658

668-
const tableCr = this.table[2 * k];
669-
const tableCi = inv * this.table[2 * k + 1];
659+
const tableCr = table[2 * k];
660+
const tableCi = inv * table[2 * k + 1];
670661
const MCr = Cr * tableCr - Ci * tableCi;
671662
const MCi = Cr * tableCi + Ci * tableCr;
672663

673-
const tableDr = this.table[3 * k];
674-
const tableDi = inv * this.table[3 * k + 1];
664+
const tableDr = table[3 * k];
665+
const tableDi = inv * table[3 * k + 1];
675666
const MDr = Dr * tableDr - Di * tableDi;
676667
const MDi = Dr * tableDi + Di * tableDr;
677668

678669
// Pre-Final values
679-
const T0r = Ar + MCr;
680-
const T0i = Ai + MCi;
681-
const T1r = Ar - MCr;
682-
const T1i = Ai - MCi;
670+
const T0r = MAr + MCr;
671+
const T0i = MAi + MCi;
672+
const T1r = MAr - MCr;
673+
const T1i = MAi - MCi;
683674
const T2r = MBr + MDr;
684675
const T2i = MBi + MDi;
685676
const T3r = inv * (MBr - MDr);
@@ -690,13 +681,35 @@ class P2FFT {
690681
out[A + 1] = T0i + T2i;
691682
out[B] = T1r + T3i;
692683
out[B + 1] = T1i - T3r;
693-
out[C] = T0r - T2r;
694-
out[C + 1] = T0i - T2i;
695-
out[D] = T1r - T3i;
696-
out[D + 1] = T1i + T3r;
684+
685+
// Output final middle point
686+
if (i === 0) {
687+
out[C] = T0r - T2r;
688+
out[C + 1] = T0i - T2i;
689+
continue;
690+
}
691+
692+
// Do not overwrite ourselves
693+
if (i === hquarterLen)
694+
continue;
695+
696+
const SA = outOff + quarterLen - i;
697+
const SB = outOff + halfLen - i;
698+
699+
out[SA] = T1r - inv * T3i;
700+
out[SA + 1] = -T1i - inv * T3r;
701+
out[SB] = T0r - inv * T2r;
702+
out[SB + 1] = -T0i + inv * T2i;
697703
}
698704
}
699705
}
706+
707+
// Complete the spectrum by adding its mirrored negative frequency components.
708+
const half = size >>> 1;
709+
for (let i = 2; i < half; i += 2) {
710+
out[size - i] = out[i];
711+
out[size - i + 1] = -out[i + 1];
712+
}
700713
}
701714

702715
/**

0 commit comments

Comments
 (0)