
Commit aa60302

Add support for Moonshine ASR (#1099)
* Add support for Moonshine ASR
* Add ASR pipeline API support for moonshine
* Add moonshine feature extractor unit test
* Pass moonshine pipeline generation kwargs to generate
Parent: 5334e7e
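As context, a minimal usage sketch of the feature this commit enables. It is not part of the commit itself: the call follows the library's existing ASR pipeline API, and the model id is the one used by the unit test below.

    import { pipeline } from '@huggingface/transformers';

    // Model id taken from the unit test in this commit; any Moonshine ONNX export should work.
    const transcriber = await pipeline(
        'automatic-speech-recognition',
        'onnx-community/moonshine-tiny-ONNX',
    );

    // Input may be a URL or raw audio samples; the pipeline resamples via prepareAudios.
    const output = await transcriber('https://example.com/audio.wav'); // placeholder URL
    console.log(output); // { text: '...' }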

10 files changed: 135 additions, 0 deletions

README.md (1 addition, 0 deletions)

@@ -366,6 +366,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)** (from Apple) released with the paper [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178) by Sachin Mehta and Mohammad Rastegari.
 1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
 1. **Moondream1** released in the repository [moondream](https://github.com/vikhyat/moondream) by vikhyat.
+1. **[Moonshine](https://huggingface.co/docs/transformers/model_doc/moonshine)** (from Useful Sensors) released with the paper [Moonshine: Speech Recognition for Live Transcription and Voice Commands](https://arxiv.org/abs/2410.15608) by Nat Jeffries, Evan King, Manjunath Kudlur, Guy Nicholson, James Wang, Pete Warden.
 1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
 1. **[MPT](https://huggingface.co/docs/transformers/model_doc/mpt)** (from MosaicML) released with the repository [llm-foundry](https://github.com/mosaicml/llm-foundry/) by the MosaicML NLP Team.
 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.

docs/snippets/6_supported-models.snippet (1 addition, 0 deletions)

@@ -81,6 +81,7 @@
 1. **[MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)** (from Apple) released with the paper [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178) by Sachin Mehta and Mohammad Rastegari.
 1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
 1. **Moondream1** released in the repository [moondream](https://github.com/vikhyat/moondream) by vikhyat.
+1. **[Moonshine](https://huggingface.co/docs/transformers/model_doc/moonshine)** (from Useful Sensors) released with the paper [Moonshine: Speech Recognition for Live Transcription and Voice Commands](https://arxiv.org/abs/2410.15608) by Nat Jeffries, Evan King, Manjunath Kudlur, Guy Nicholson, James Wang, Pete Warden.
 1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
 1. **[MPT](https://huggingface.co/docs/transformers/model_doc/mpt)** (from MosaicML) released with the repository [llm-foundry](https://github.com/mosaicml/llm-foundry/) by the MosaicML NLP Team.
 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.

src/configs.js (1 addition, 0 deletions)

@@ -186,6 +186,7 @@ function getNormalizedConfig(config) {
             mapping['encoder_hidden_size'] = mapping['decoder_hidden_size'] = 'd_model';
             break;
         case 'musicgen_decoder':
+        case 'moonshine':
             mapping['num_encoder_layers'] = mapping['num_decoder_layers'] = 'num_hidden_layers';
             mapping['num_encoder_heads'] = mapping['num_decoder_heads'] = 'num_attention_heads';
             mapping['encoder_hidden_size'] = mapping['decoder_hidden_size'] = 'hidden_size';
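A minimal sketch of what the fall-through above means in practice: for a 'moonshine' config, encoder and decoder dimensions are all read from the same three keys, just as for the MusicGen decoder. The numeric values below are hypothetical, chosen only for illustration.

    // Hypothetical Moonshine config values, for illustration only.
    const config = { num_hidden_layers: 6, num_attention_heads: 8, hidden_size: 288 };

    // Effect of the shared mapping: encoder and decoder settings resolve to the same keys.
    const normalized = {
        num_encoder_layers: config.num_hidden_layers,   // 6
        num_decoder_layers: config.num_hidden_layers,   // 6
        num_encoder_heads: config.num_attention_heads,  // 8
        num_decoder_heads: config.num_attention_heads,  // 8
        encoder_hidden_size: config.hidden_size,        // 288
        decoder_hidden_size: config.hidden_size,        // 288
    };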

src/models.js (24 additions, 0 deletions)

@@ -3359,6 +3359,29 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
 }
 //////////////////////////////////////////////////

+
+//////////////////////////////////////////////////
+// Moonshine models
+export class MoonshinePreTrainedModel extends PreTrainedModel {
+
+    requires_attention_mask = false;
+    main_input_name = 'input_values';
+    forward_params = [
+        'input_values',
+        'decoder_input_ids',
+        'past_key_values',
+    ];
+};
+
+/**
+ * MoonshineModel class for training Moonshine models without a language model head.
+ */
+export class MoonshineModel extends MoonshinePreTrainedModel { }
+
+export class MoonshineForConditionalGeneration extends MoonshinePreTrainedModel { }
+//////////////////////////////////////////////////
+

 //////////////////////////////////////////////////
 /**
  * Vision Encoder-Decoder model based on OpenAI's GPT architecture for image captioning and other vision tasks

@@ -7013,6 +7036,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
 const MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = new Map([
     ['speecht5', ['SpeechT5ForSpeechToText', SpeechT5ForSpeechToText]],
     ['whisper', ['WhisperForConditionalGeneration', WhisperForConditionalGeneration]],
+    ['moonshine', ['MoonshineForConditionalGeneration', MoonshineForConditionalGeneration]],
 ]);

 const MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = new Map([
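Because the new mapping entry registers Moonshine among the speech seq2seq models, it should also be loadable outside the pipeline API. A sketch under that assumption, using the library's existing Auto* entry points and the model id from the unit test; the diff itself only confirms the generate/decode calls used by the pipeline:

    import { AutoModelForSpeechSeq2Seq, AutoProcessor, AutoTokenizer } from '@huggingface/transformers';

    const model_id = 'onnx-community/moonshine-tiny-ONNX';
    const processor = await AutoProcessor.from_pretrained(model_id);
    const tokenizer = await AutoTokenizer.from_pretrained(model_id);
    const model = await AutoModelForSpeechSeq2Seq.from_pretrained(model_id);

    // One second of silence as stand-in audio; the processor adds the batch dimension.
    const audio = new Float32Array(16000);
    const inputs = await processor(audio); // { input_values: Tensor of shape [1, 16000] }
    const outputs = await model.generate({ ...inputs, max_new_tokens: 6 });
    const text = tokenizer.batch_decode(outputs, { skip_special_tokens: true })[0];
    console.log(text);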

src/models/feature_extractors.js (1 addition, 0 deletions)

@@ -1,6 +1,7 @@

 export * from './audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.js';
 export * from './clap/feature_extraction_clap.js';
+export * from './moonshine/feature_extraction_moonshine.js';
 export * from './pyannote/feature_extraction_pyannote.js';
 export * from './seamless_m4t/feature_extraction_seamless_m4t.js';
 export * from './speecht5/feature_extraction_speecht5.js';
src/models/moonshine/feature_extraction_moonshine.js (new file: 26 additions, 0 deletions)

@@ -0,0 +1,26 @@
+import { FeatureExtractor, validate_audio_inputs } from '../../base/feature_extraction_utils.js';
+import { Tensor } from '../../utils/tensor.js';
+
+
+export class MoonshineFeatureExtractor extends FeatureExtractor {
+    /**
+     * Asynchronously extracts input values from a given audio using the provided configuration.
+     * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array.
+     * @returns {Promise<{ input_values: Tensor; }>} The extracted input values.
+     */
+    async _call(audio) {
+        validate_audio_inputs(audio, 'MoonshineFeatureExtractor');
+
+        if (audio instanceof Float64Array) {
+            audio = new Float32Array(audio);
+        }
+
+        const shape = [
+            1, /* batch_size */
+            audio.length, /* num_samples */
+        ];
+        return {
+            input_values: new Tensor('float32', audio, shape),
+        };
+    }
+}
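Unlike Whisper's log-mel feature extractor, this one computes no spectrogram: it validates the input, casts Float64Array to Float32Array, and wraps the raw waveform in a [1, num_samples] tensor. A quick sketch of the resulting shape, using the model id from the unit test and an assumed 16 kHz sampling rate:

    import { AutoFeatureExtractor } from '@huggingface/transformers';

    const feature_extractor = await AutoFeatureExtractor.from_pretrained('onnx-community/moonshine-tiny-ONNX');

    const audio = new Float32Array(32000); // 2 seconds at an assumed 16 kHz
    const { input_values } = await feature_extractor(audio);
    console.log(input_values.dims); // [1, 32000], i.e. batch_size x num_samples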
src/models/moonshine/processing_moonshine.js (new file: 20 additions, 0 deletions)

@@ -0,0 +1,20 @@
+import { AutoFeatureExtractor } from "../auto/feature_extraction_auto.js"
+import { AutoTokenizer } from "../../tokenizers.js"
+import { Processor } from "../../base/processing_utils.js"
+
+/**
+ * Represents a MoonshineProcessor that extracts features from an audio input.
+ */
+export class MoonshineProcessor extends Processor {
+    static tokenizer_class = AutoTokenizer
+    static feature_extractor_class = AutoFeatureExtractor
+
+    /**
+     * Calls the feature_extractor function with the given audio input.
+     * @param {any} audio The audio input to extract features from.
+     * @returns {Promise<any>} A Promise that resolves with the extracted features.
+     */
+    async _call(audio) {
+        return await this.feature_extractor(audio);
+    }
+}

src/models/processors.js (1 addition, 0 deletions)

@@ -1,5 +1,6 @@
 export * from './florence2/processing_florence2.js';
 export * from './mgp_str/processing_mgp_str.js';
+export * from './moonshine/processing_moonshine.js';
 export * from './idefics3/processing_idefics3.js';
 export * from './janus/processing_janus.js';
 export * from './jina_clip/processing_jina_clip.js';

src/pipelines.js (30 additions, 0 deletions)

@@ -1729,6 +1729,8 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options
             case 'unispeech-sat':
             case 'hubert':
                 return this._call_wav2vec2(audio, kwargs)
+            case 'moonshine':
+                return this._call_moonshine(audio, kwargs)
             default:
                 throw new Error(`AutomaticSpeechRecognitionPipeline does not support model type '${this.model.config.model_type}'.`)
         }

@@ -1882,6 +1884,34 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options
         }
         return single ? toReturn[0] : toReturn;
     }
+
+    /**
+     * @type {AutomaticSpeechRecognitionPipelineCallback}
+     * @private
+     */
+    async _call_moonshine(audio, kwargs) {
+        const single = !Array.isArray(audio);
+        if (single) {
+            audio = [/** @type {AudioInput} */ (audio)];
+        }
+        const sampling_rate = this.processor.feature_extractor.config.sampling_rate;
+        const preparedAudios = await prepareAudios(audio, sampling_rate);
+        const toReturn = [];
+        for (const aud of preparedAudios) {
+            const inputs = await this.processor(aud);
+
+            // According to the [paper](https://arxiv.org/pdf/2410.15608):
+            // "We use greedy decoding, with a heuristic limit of 6 output tokens
+            // per second of audio to avoid repeated output sequences."
+            const max_new_tokens = Math.floor(aud.length / sampling_rate) * 6;
+            const outputs = await this.model.generate({ max_new_tokens, ...kwargs, ...inputs });
+
+            const text = this.processor.batch_decode(outputs, { skip_special_tokens: true })[0];
+            toReturn.push({ text });
+        }
+        return single ? toReturn[0] : toReturn;
+    }
+
 }

 /**
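Note the spread order in generate({ max_new_tokens, ...kwargs, ...inputs }): user-supplied generation kwargs override the 6-tokens-per-second default, which is the "Pass moonshine pipeline generation kwargs to generate" item from the commit message. A sketch of both behaviours, assuming a transcriber pipeline constructed as in the example near the top:

    // Default: a 13-second clip is capped at floor(13) * 6 = 78 new tokens.
    const output = await transcriber(audio);

    // Explicit kwargs win, since ...kwargs is spread after the heuristic max_new_tokens.
    const shortOutput = await transcriber(audio, { max_new_tokens: 20 });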
Moonshine feature extractor unit test (new file: 30 additions, 0 deletions)

@@ -0,0 +1,30 @@
+import { AutoFeatureExtractor, MoonshineFeatureExtractor } from "../../../src/transformers.js";
+
+import { load_cached_audio } from "../../asset_cache.js";
+import { MAX_FEATURE_EXTRACTOR_LOAD_TIME, MAX_TEST_EXECUTION_TIME } from "../../init.js";
+
+export default () => {
+    // MoonshineFeatureExtractor
+    describe("MoonshineFeatureExtractor", () => {
+        const model_id = "onnx-community/moonshine-tiny-ONNX";
+
+        /** @type {MoonshineFeatureExtractor} */
+        let feature_extractor;
+        beforeAll(async () => {
+            feature_extractor = await AutoFeatureExtractor.from_pretrained(model_id);
+        }, MAX_FEATURE_EXTRACTOR_LOAD_TIME);
+
+        it(
+            "default",
+            async () => {
+                const audio = await load_cached_audio("mlk");
+                const { input_values } = await feature_extractor(audio);
+                expect(input_values.dims).toEqual([1, 208000]);
+                expect(input_values.mean().item()).toBeCloseTo(-1.5654930507480458e-7, 6);
+                expect(input_values.data[0]).toBeCloseTo(0.0067138671875, 6);
+                expect(input_values.data.at(-1)).toBeCloseTo(-0.013427734375, 6);
+            },
+            MAX_TEST_EXECUTION_TIME,
+        );
+    });
+};
