
Commit 5658edd

Add support for Parakeet CTC (#1440)
* Add support for Parakeet CTC
* Add parakeet to list of supported models
* Skip special tokens when decoding CTC output from ASR pipeline
* Trim decoded output
* Update tests
1 parent 85b8eb2 commit 5658edd

File tree

- README.md
- docs/snippets/6_supported-models.snippet
- src/models.js
- src/models/feature_extractors.js
- src/models/parakeet/feature_extraction_parakeet.js
- src/pipelines.js
- src/utils/audio.js
- tests/pipelines/test_pipelines_automatic_speech_recognition.js

8 files changed: +162 -7 lines

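The practical upshot of the commit: Parakeet CTC checkpoints become usable through the `automatic-speech-recognition` pipeline like the other CTC models (Wav2Vec2, HuBERT, WavLM). A minimal usage sketch; the model id below is a placeholder for illustration, not a checkpoint named in this commit:

```js
import { pipeline } from '@huggingface/transformers';

// NOTE: placeholder model id; substitute an actual ONNX-converted
// Parakeet CTC checkpoint.
const transcriber = await pipeline(
    'automatic-speech-recognition',
    'onnx-community/parakeet-ctc_example',
);

// Accepts a URL/path or a Float32Array of mono audio at the model's sampling rate.
const { text } = await transcriber('audio.wav');
console.log(text);
```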

README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -404,6 +404,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://huggingface.co/papers/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (from Google AI) released with the paper [Scaling Open-Vocabulary Object Detection](https://huggingface.co/papers/2306.09683) by Matthias Minderer, Alexey Gritsenko, Neil Houlsby.
 1. **[PaliGemma](https://huggingface.co/docs/transformers/main/model_doc/paligemma)** (from Google) released with the papers [PaliGemma: A versatile 3B VLM for transfer](https://huggingface.co/papers/2407.07726) and [PaliGemma 2: A Family of Versatile VLMs for Transfer](https://huggingface.co/papers/2412.03555) by the PaliGemma Google team.
+1. **[Parakeet](https://huggingface.co/docs/transformers/main/model_doc/parakeet)** (from NVIDIA) released with the blog post [Introducing the Parakeet ASR family](https://developer.nvidia.com/blog/pushing-the-boundaries-of-speech-recognition-with-nemo-parakeet-asr-models/) by the NVIDIA NeMo team.
 1. **[PatchTSMixer](https://huggingface.co/docs/transformers/main/model_doc/patchtsmixer)** (from IBM) released with the paper [TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://huggingface.co/papers/2306.09364) by Vijay Ekambaram, Arindam Jati, Nam Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (from Princeton University, IBM) released with the paper [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://huggingface.co/papers/2211.14730) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[Phi](https://huggingface.co/docs/transformers/main/model_doc/phi)** (from Microsoft) released with the papers - [Textbooks Are All You Need](https://huggingface.co/papers/2306.11644) by Suriya Gunasekar, Yi Zhang, Jyoti Aneja, Caio César Teodoro Mendes, Allie Del Giorno, Sivakanth Gopi, Mojan Javaheripi, Piero Kauffmann, Gustavo de Rosa, Olli Saarikivi, Adil Salim, Shital Shah, Harkirat Singh Behl, Xin Wang, Sébastien Bubeck, Ronen Eldan, Adam Tauman Kalai, Yin Tat Lee and Yuanzhi Li, [Textbooks Are All You Need II: phi-1.5 technical report](https://huggingface.co/papers/2309.05463) by Yuanzhi Li, Sébastien Bubeck, Ronen Eldan, Allie Del Giorno, Suriya Gunasekar and Yin Tat Lee.
```

docs/snippets/6_supported-models.snippet

Lines changed: 1 addition & 0 deletions
```diff
@@ -118,6 +118,7 @@
 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://huggingface.co/papers/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (from Google AI) released with the paper [Scaling Open-Vocabulary Object Detection](https://huggingface.co/papers/2306.09683) by Matthias Minderer, Alexey Gritsenko, Neil Houlsby.
 1. **[PaliGemma](https://huggingface.co/docs/transformers/main/model_doc/paligemma)** (from Google) released with the papers [PaliGemma: A versatile 3B VLM for transfer](https://huggingface.co/papers/2407.07726) and [PaliGemma 2: A Family of Versatile VLMs for Transfer](https://huggingface.co/papers/2412.03555) by the PaliGemma Google team.
+1. **[Parakeet](https://huggingface.co/docs/transformers/main/model_doc/parakeet)** (from NVIDIA) released with the blog post [Introducing the Parakeet ASR family](https://developer.nvidia.com/blog/pushing-the-boundaries-of-speech-recognition-with-nemo-parakeet-asr-models/) by the NVIDIA NeMo team.
 1. **[PatchTSMixer](https://huggingface.co/docs/transformers/main/model_doc/patchtsmixer)** (from IBM) released with the paper [TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://huggingface.co/papers/2306.09364) by Vijay Ekambaram, Arindam Jati, Nam Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (from Princeton University, IBM) released with the paper [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://huggingface.co/papers/2211.14730) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[Phi](https://huggingface.co/docs/transformers/main/model_doc/phi)** (from Microsoft) released with the papers - [Textbooks Are All You Need](https://huggingface.co/papers/2306.11644) by Suriya Gunasekar, Yi Zhang, Jyoti Aneja, Caio César Teodoro Mendes, Allie Del Giorno, Sivakanth Gopi, Mojan Javaheripi, Piero Kauffmann, Gustavo de Rosa, Olli Saarikivi, Adil Salim, Shital Shah, Harkirat Singh Behl, Xin Wang, Sébastien Bubeck, Ronen Eldan, Adam Tauman Kalai, Yin Tat Lee and Yuanzhi Li, [Textbooks Are All You Need II: phi-1.5 technical report](https://huggingface.co/papers/2309.05463) by Yuanzhi Li, Sébastien Bubeck, Ronen Eldan, Allie Del Giorno, Suriya Gunasekar and Yin Tat Lee.
```

src/models.js

Lines changed: 16 additions & 0 deletions
```diff
@@ -6161,6 +6161,21 @@ export class Wav2Vec2ForAudioFrameClassification extends Wav2Vec2PreTrainedModel
 }
 //////////////////////////////////////////////////

+//////////////////////////////////////////////////
+// Parakeet models
+export class ParakeetPreTrainedModel extends PreTrainedModel { };
+export class ParakeetForCTC extends ParakeetPreTrainedModel {
+    /**
+     * @param {Object} model_inputs
+     * @param {Tensor} model_inputs.input_values Float values of input raw speech waveform.
+     * @param {Tensor} model_inputs.attention_mask Mask to avoid performing convolution and attention on padding token indices. Mask values selected in [0, 1]
+     */
+    async _call(model_inputs) {
+        return new CausalLMOutput(await super._call(model_inputs));
+    }
+}
+//////////////////////////////////////////////////
+

 //////////////////////////////////////////////////
 // PyAnnote models
@@ -8148,6 +8163,7 @@ const MODEL_FOR_CTC_MAPPING_NAMES = new Map([
     ['unispeech-sat', ['UniSpeechSatForCTC', UniSpeechSatForCTC]],
     ['wavlm', ['WavLMForCTC', WavLMForCTC]],
     ['hubert', ['HubertForCTC', HubertForCTC]],
+    ['parakeet_ctc', ['ParakeetForCTC', ParakeetForCTC]],
 ]);

 const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([
```
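The `parakeet_ctc` entry in `MODEL_FOR_CTC_MAPPING_NAMES` is what lets the auto classes resolve Parakeet checkpoints; `ParakeetForCTC._call` simply forwards the processed inputs and wraps the session output as a `CausalLMOutput`, i.e. an object carrying `logits`. A rough sketch of how the pieces connect (the model id is a placeholder, and the exact processor output shape is an assumption based on the feature extractor below):

```js
import { AutoProcessor, AutoModelForCTC } from '@huggingface/transformers';

// Placeholder id for illustration; any checkpoint whose config declares
// model_type "parakeet_ctc" would resolve to ParakeetForCTC.
const model_id = 'onnx-community/parakeet-ctc_example';
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModelForCTC.from_pretrained(model_id);

const audio = new Float32Array(16000); // 1 second of silence, for illustration

// The Parakeet feature extractor returns { input_features, attention_mask },
// which the model forwards; the result carries per-frame CTC logits.
const inputs = await processor(audio);
const { logits } = await model(inputs); // dims: [batch, frames, vocab_size]
```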

src/models/feature_extractors.js

Lines changed: 1 addition & 0 deletions
```diff
@@ -5,6 +5,7 @@ export * from './clap/feature_extraction_clap.js';
 export * from './dac/feature_extraction_dac.js';
 export * from './gemma3n/feature_extraction_gemma3n.js';
 export * from './moonshine/feature_extraction_moonshine.js';
+export * from './parakeet/feature_extraction_parakeet.js';
 export * from './pyannote/feature_extraction_pyannote.js';
 export * from './seamless_m4t/feature_extraction_seamless_m4t.js';
 export * from './snac/feature_extraction_snac.js';
```
src/models/parakeet/feature_extraction_parakeet.js

Lines changed: 121 additions & 0 deletions (new file)

```js
import { FeatureExtractor, validate_audio_inputs } from '../../base/feature_extraction_utils.js';
import { Tensor } from '../../utils/tensor.js';
import { mel_filter_bank, spectrogram, window_function } from '../../utils/audio.js';

const EPSILON = 1e-5;

export class ParakeetFeatureExtractor extends FeatureExtractor {

    constructor(config) {
        super(config);

        // Prefer given `mel_filters` from preprocessor_config.json, or calculate them if they don't exist.
        this.config.mel_filters ??= mel_filter_bank(
            Math.floor(1 + this.config.n_fft / 2), // num_frequency_bins
            this.config.feature_size, // num_mel_filters
            0.0, // min_frequency
            this.config.sampling_rate / 2, // max_frequency
            this.config.sampling_rate, // sampling_rate
            "slaney", // norm
            "slaney", // mel_scale
        );

        const window = window_function(this.config.win_length, 'hann', {
            periodic: false,
        });

        this.window = new Float64Array(this.config.n_fft);
        const offset = Math.floor((this.config.n_fft - this.config.win_length) / 2);
        this.window.set(window, offset);
    }

    /**
     * Computes the log-Mel spectrogram of the provided audio waveform.
     * @param {Float32Array|Float64Array} waveform The audio waveform to process.
     * @returns {Promise<Tensor>} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers.
     */
    async _extract_fbank_features(waveform) {
        // Parakeet uses a custom preemphasis strategy: apply preemphasis to the entire waveform at once
        const preemphasis = this.config.preemphasis;
        waveform = new Float64Array(waveform); // Clone to avoid destructive changes
        for (let j = waveform.length - 1; j >= 1; --j) {
            waveform[j] -= preemphasis * waveform[j - 1];
        }

        const features = await spectrogram(
            waveform,
            this.window, // window
            this.window.length, // frame_length
            this.config.hop_length, // hop_length
            {
                fft_length: this.config.n_fft,
                power: 2.0,
                mel_filters: this.config.mel_filters,
                log_mel: 'log',
                mel_floor: -Infinity,
                pad_mode: 'constant',
                center: true,

                // Custom
                transpose: true,
                mel_offset: 2 ** -24,
            }
        );

        return features;
    }

    /**
     * Asynchronously extracts features from a given audio using the provided configuration.
     * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array.
     * @returns {Promise<{ input_features: Tensor; attention_mask: Tensor; }>} A Promise resolving to an object containing the extracted input features as a Tensor.
     */
    async _call(audio) {
        validate_audio_inputs(audio, 'ParakeetFeatureExtractor');

        const features = await this._extract_fbank_features(audio);

        const features_length = Math.floor(
            (audio.length + Math.floor(this.config.n_fft / 2) * 2 - this.config.n_fft) / this.config.hop_length
        );

        const features_data = /** @type {Float32Array} */ (features.data);
        features_data.fill(0, features_length * features.dims[1]);

        // normalize mel features, ignoring padding
        const [num_frames, num_features] = features.dims;
        const sum = new Float64Array(num_features);
        const sum_sq = new Float64Array(num_features);

        for (let i = 0; i < features_length; ++i) {
            const offset = i * num_features;
            for (let j = 0; j < num_features; ++j) {
                const val = features_data[offset + j];
                sum[j] += val;
                sum_sq[j] += val * val;
            }
        }

        // Calculate mean and standard deviation, then normalize
        const divisor = features_length > 1 ? features_length - 1 : 1;
        for (let j = 0; j < num_features; ++j) {
            const mean = sum[j] / features_length;
            const variance = (sum_sq[j] - features_length * mean * mean) / divisor;
            const std = Math.sqrt(variance) + EPSILON;
            const inv_std = 1 / std;

            for (let i = 0; i < features_length; ++i) {
                const index = i * num_features + j;
                features_data[index] = (features_data[index] - mean) * inv_std;
            }
        }

        const mask_data = new BigInt64Array(num_frames);
        mask_data.fill(1n, 0, features_length);

        return {
            input_features: features.unsqueeze_(0),
            attention_mask: new Tensor('int64', mask_data, [1, num_frames]),
        };
    }
}
```
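Two steps in this extractor are easy to misread: the preemphasis filter, which computes y[t] = x[t] - c·x[t-1] in place by iterating backwards (so each step still reads the unmodified previous sample), and the normalization, which computes a per-feature mean and unbiased variance over only the first `features_length` (unpadded) frames. A standalone sketch of both, with invented shapes and an illustrative coefficient:

```js
// Preemphasis: first-order high-pass, y[t] = x[t] - c * x[t-1].
// Iterating from the end means x[t-1] is read before it is overwritten.
function preemphasize(x, c = 0.97) { // c is illustrative; Parakeet reads it from config.preemphasis
    const y = new Float64Array(x);
    for (let t = y.length - 1; t >= 1; --t) {
        y[t] -= c * y[t - 1];
    }
    return y;
}

// Per-feature normalization over the first `valid` frames of a row-major
// (num_frames x num_features) buffer, mirroring ParakeetFeatureExtractor._call.
function normalize_features(data, valid, num_features, eps = 1e-5) {
    const divisor = valid > 1 ? valid - 1 : 1; // unbiased (sample) variance
    for (let j = 0; j < num_features; ++j) {
        let sum = 0, sum_sq = 0;
        for (let i = 0; i < valid; ++i) {
            const v = data[i * num_features + j];
            sum += v;
            sum_sq += v * v;
        }
        const mean = sum / valid;
        const variance = (sum_sq - valid * mean * mean) / divisor;
        const inv_std = 1 / (Math.sqrt(variance) + eps);
        for (let i = 0; i < valid; ++i) {
            data[i * num_features + j] = (data[i * num_features + j] - mean) * inv_std;
        }
    }
}
```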

src/pipelines.js

Lines changed: 2 additions & 1 deletion
```diff
@@ -1750,6 +1750,7 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options
             case 'unispeech':
             case 'unispeech-sat':
             case 'hubert':
+            case 'parakeet_ctc':
                 return this._call_wav2vec2(audio, kwargs)
             case 'moonshine':
                 return this._call_moonshine(audio, kwargs)
@@ -1790,7 +1791,7 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options
             for (const item of logits) {
                 predicted_ids.push(max(item.data)[1])
             }
-            const predicted_sentences = this.tokenizer.decode(predicted_ids)
+            const predicted_sentences = this.tokenizer.decode(predicted_ids, { skip_special_tokens: true }).trim();
             toReturn.push({ text: predicted_sentences })
         }
         return single ? toReturn[0] : toReturn;
```
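The decode change is the behavioral fix in this commit: greedy CTC decoding takes a per-frame argmax, so the id stream is dominated by pad/blank and other special tokens, and decoding with `skip_special_tokens: true` plus a `trim()` keeps them out of the final transcript (compare the test update below, where the expected output drops "<unk>"). A toy illustration of the greedy step, assuming an already-loaded CTC `tokenizer` and an invented 3-token vocabulary:

```js
// Invented vocabulary: id 0 = a special token (e.g. pad/blank), ids 1-2 = characters.
const logits = [
    [0.1, 2.0, 0.3], // frame 0 -> id 1
    [0.2, 1.5, 0.1], // frame 1 -> id 1 (repeat; the CTC tokenizer collapses it)
    [3.0, 0.1, 0.2], // frame 2 -> id 0 (special token; previously leaked into the text)
];
const predicted_ids = logits.map((frame) => frame.indexOf(Math.max(...frame)));

// Before this commit: this.tokenizer.decode(predicted_ids) could leave
// special tokens such as "<unk>" in the transcript.
const text = tokenizer.decode(predicted_ids, { skip_special_tokens: true }).trim();
```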

src/utils/audio.js

Lines changed: 19 additions & 5 deletions
```diff
@@ -470,6 +470,7 @@ function power_to_db(spectrogram, reference = 1.0, min_value = 1e-10, db_range =
  * @param {number} [options.min_num_frames=null] If provided, ensures the number of frames to compute is at least this value.
  * @param {boolean} [options.do_pad=true] If `true`, pads the output spectrogram to have `max_num_frames` frames.
  * @param {boolean} [options.transpose=false] If `true`, the returned spectrogram will have shape `(num_frames, num_frequency_bins/num_mel_filters)`. If `false`, the returned spectrogram will have shape `(num_frequency_bins/num_mel_filters, num_frames)`.
+ * @param {number} [options.mel_offset=0] Offset to add to the mel spectrogram to avoid taking the log of zero.
  * @returns {Promise<Tensor>} Spectrogram of shape `(num_frequency_bins, length)` (regular spectrogram) or shape `(num_mel_filters, length)` (mel spectrogram).
  */
 export async function spectrogram(
@@ -498,6 +499,7 @@
         max_num_frames = null,
         do_pad = true,
         transpose = false,
+        mel_offset = 0,
     } = {}
 ) {
     const window_length = window.length;
@@ -530,11 +532,23 @@
     }

     if (center) {
-        if (pad_mode !== 'reflect') {
-            throw new Error(`pad_mode="${pad_mode}" not implemented yet.`)
+        switch (pad_mode) {
+            case 'reflect': {
+                const half_window = Math.floor((fft_length - 1) / 2) + 1;
+                waveform = padReflect(waveform, half_window, half_window);
+                break;
+            }
+            case 'constant': {
+                const padding = Math.floor(fft_length / 2);
+                // @ts-expect-error ts(2351)
+                const padded = new waveform.constructor(waveform.length + 2 * padding);
+                padded.set(waveform, padding);
+                waveform = padded;
+                break;
+            }
+            default:
+                throw new Error(`pad_mode="${pad_mode}" not implemented yet.`);
         }
-        const half_window = Math.floor((fft_length - 1) / 2) + 1;
-        waveform = padReflect(waveform, half_window, half_window);
     }

     // split waveform into frames of frame_length size
@@ -641,7 +655,7 @@

     const mel_spec_data = /** @type {Float32Array} */(mel_spec.data);
     for (let i = 0; i < mel_spec_data.length; ++i) {
-        mel_spec_data[i] = Math.max(mel_floor, mel_spec_data[i]);
+        mel_spec_data[i] = mel_offset + Math.max(mel_floor, mel_spec_data[i]);
     }

     if (power !== null && log_mel !== null) {
```
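Both additions to `spectrogram()` exist to support the Parakeet extractor, which calls it with `pad_mode: 'constant'` and `mel_offset: 2 ** -24` while setting `mel_floor: -Infinity` (disabling the floor); the offset guards the subsequent log against silent frames. A standalone sketch of the two behaviors, with illustrative names:

```js
// Constant (zero) padding: center analysis frames by extending the waveform
// with fft_length / 2 zeros on each side, as the new 'constant' branch does.
function pad_constant(waveform, padding) {
    const padded = new Float64Array(waveform.length + 2 * padding);
    padded.set(waveform, padding); // samples outside the copy stay 0
    return padded;
}

const padded = pad_constant(new Float64Array([1, 2, 3]), 2);
console.log(padded); // Float64Array [0, 0, 1, 2, 3, 0, 0]

// Why mel_offset: with the mel floor disabled, a silent frame yields a mel
// energy of exactly 0, and log(0) is -Infinity. Adding 2^-24 bounds the log.
const mel_offset = 2 ** -24;
console.log(Math.log(0));              // -Infinity
console.log(Math.log(0 + mel_offset)); // ~ -16.64
```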

tests/pipelines/test_pipelines_automatic_speech_recognition.js

Lines changed: 1 addition & 1 deletion
```diff
@@ -114,7 +114,7 @@ export default () => {
             "default",
             async () => {
                 const output = await pipe(audios[0], { max_new_tokens });
-                const target = { text: "<unk>K" };
+                const target = { text: "K" };
                 expect(output).toEqual(target);
             },
             MAX_TEST_EXECUTION_TIME,
```
