
Commit 46c8320

Merge text-to-audio pipeline
1 parent e47b59b commit 46c8320

File tree

1 file changed: +88 −34


src/pipelines/text-to-audio.js

Lines changed: 88 additions & 34 deletions
@@ -23,6 +23,9 @@ import { AutoModel } from '../models.js';
  *
  * @typedef {Object} TextToAudioPipelineOptions Parameters specific to text-to-audio pipelines.
  * @property {Tensor|Float32Array|string|URL} [speaker_embeddings=null] The speaker embeddings (if the model requires it).
+ * @property {number} [num_inference_steps] The number of denoising steps (if the model supports it).
+ * More denoising steps usually lead to higher quality audio but slower inference.
+ * @property {number} [speed] The speed of the generated audio (if the model supports it).
  *
  * @callback TextToAudioPipelineCallback Generates speech/audio from the inputs.
  * @param {string|string[]} texts The text(s) to generate.
@@ -36,31 +39,24 @@ import { AutoModel } from '../models.js';
  * Text-to-audio generation pipeline using any `AutoModelForTextToWaveform` or `AutoModelForTextToSpectrogram`.
  * This pipeline generates an audio file from an input text and optional other conditional inputs.
  *
- * **Example:** Generate audio from text with `Xenova/speecht5_tts`.
+ * **Example:** Generate audio from text with `onnx-community/Supertonic-TTS-ONNX`.
  * ```javascript
- * const synthesizer = await pipeline('text-to-speech', 'Xenova/speecht5_tts', { quantized: false });
- * const speaker_embeddings = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin';
- * const out = await synthesizer('Hello, my dog is cute', { speaker_embeddings });
+ * const synthesizer = await pipeline('text-to-speech', 'onnx-community/Supertonic-TTS-ONNX');
+ * const speaker_embeddings = 'https://huggingface.co/onnx-community/Supertonic-TTS-ONNX/resolve/main/voices/F1.bin';
+ * const output = await synthesizer('Hello there, how are you doing?', { speaker_embeddings });
  * // RawAudio {
- * //   audio: Float32Array(26112) [-0.00005657337896991521, 0.00020583874720614403, ...],
- * //   sampling_rate: 16000
+ * //   audio: Float32Array(95232) [-0.000482565927086398, -0.0004853440332226455, ...],
+ * //   sampling_rate: 44100
  * // }
- * ```
- *
- * You can then save the audio to a .wav file with the `wavefile` package:
- * ```javascript
- * import wavefile from 'wavefile';
- * import fs from 'fs';
- *
- * const wav = new wavefile.WaveFile();
- * wav.fromScratch(1, out.sampling_rate, '32f', out.audio);
- * fs.writeFileSync('out.wav', wav.toBuffer());
+ *
+ * // Optional: Save the audio to a .wav file or Blob
+ * await output.save('output.wav'); // You can also use `output.toBlob()` to access the audio as a Blob
  * ```
  *
  * **Example:** Multilingual speech generation with `Xenova/mms-tts-fra`. See [here](https://huggingface.co/models?pipeline_tag=text-to-speech&other=vits&sort=trending) for the full list of available languages (1107).
  * ```javascript
  * const synthesizer = await pipeline('text-to-speech', 'Xenova/mms-tts-fra');
- * const out = await synthesizer('Bonjour');
+ * const output = await synthesizer('Bonjour');
  * // RawAudio {
  * //   audio: Float32Array(23808) [-0.00037693005288019776, 0.0003325853613205254, ...],
  * //   sampling_rate: 16000
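
The new `num_inference_steps` and `speed` options documented in the typedef above can be combined with the Supertonic example. A minimal sketch, not taken from the commit: the import path assumes the standard `@huggingface/transformers` entry point, and the numeric values are illustrative only.

```javascript
import { pipeline } from '@huggingface/transformers';

// Model ID and voice file are the ones used in the updated docstring example.
const synthesizer = await pipeline('text-to-speech', 'onnx-community/Supertonic-TTS-ONNX');
const speaker_embeddings = 'https://huggingface.co/onnx-community/Supertonic-TTS-ONNX/resolve/main/voices/F1.bin';

const output = await synthesizer('Hello there, how are you doing?', {
    speaker_embeddings,
    num_inference_steps: 20, // illustrative: more steps usually mean higher quality but slower inference
    speed: 1.1,              // illustrative: scales the pace of the generated speech
});
await output.save('output.wav');
```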
@@ -83,17 +79,78 @@ export class TextToAudioPipeline
         this.vocoder = options.vocoder ?? null;
     }
 
+    async _prepare_speaker_embeddings(speaker_embeddings) {
+        // Load speaker embeddings as Float32Array from path/URL
+        if (typeof speaker_embeddings === 'string' || speaker_embeddings instanceof URL) {
+            // Load from URL with fetch
+            speaker_embeddings = new Float32Array(
+                await (await fetch(speaker_embeddings)).arrayBuffer()
+            );
+        }
+
+        if (speaker_embeddings instanceof Float32Array) {
+            speaker_embeddings = new Tensor(
+                'float32',
+                speaker_embeddings,
+                [speaker_embeddings.length]
+            )
+        } else if (!(speaker_embeddings instanceof Tensor)) {
+            throw new Error("Speaker embeddings must be a `Tensor`, `Float32Array`, `string`, or `URL`.")
+        }
+
+        return speaker_embeddings;
+    }
+
     /** @type {TextToAudioPipelineCallback} */
-    async _call(text_inputs, { speaker_embeddings = null } = {}) {
+    async _call(text_inputs, {
+        speaker_embeddings = null,
+        num_inference_steps,
+        speed,
+    } = {}) {
+
         // If this.processor is not set, we are using a `AutoModelForTextToWaveform` model
         if (this.processor) {
             return this._call_text_to_spectrogram(text_inputs, { speaker_embeddings });
+        } else if (
+            this.model.config.model_type === "supertonic"
+        ) {
+            return this._call_supertonic(text_inputs, { speaker_embeddings, num_inference_steps, speed });
         } else {
             return this._call_text_to_waveform(text_inputs);
         }
     }
 
+    async _call_supertonic(text_inputs, { speaker_embeddings, num_inference_steps, speed }) {
+        if (!speaker_embeddings) {
+            throw new Error("Speaker embeddings must be provided for Supertonic models.");
+        }
+        speaker_embeddings = await this._prepare_speaker_embeddings(speaker_embeddings);
+
+        // @ts-expect-error TS2339
+        const { sampling_rate, style_dim } = this.model.config;
+
+        speaker_embeddings = (/** @type {Tensor} */ (speaker_embeddings)).view(1, -1, style_dim);
+        const inputs = this.tokenizer(text_inputs, {
+            padding: true,
+            truncation: true,
+        });
+
+        // @ts-expect-error TS2339
+        const { waveform } = await this.model.generate_speech({
+            ...inputs,
+            style: speaker_embeddings,
+            num_inference_steps,
+            speed,
+        });
+
+        return new RawAudio(
+            waveform.data,
+            sampling_rate,
+        )
+    }
+
     async _call_text_to_waveform(text_inputs) {
+
         // Run tokenization
         const inputs = this.tokenizer(text_inputs, {
             padding: true,
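
Because `_prepare_speaker_embeddings` above also accepts a `Tensor` or `Float32Array` (matching the `TextToAudioPipelineOptions` typedef), a voice file can be fetched once and reused across calls instead of being re-downloaded from the URL on every invocation. A rough sketch, not part of this commit; the voice URL is the one from the docstring example.

```javascript
import { pipeline } from '@huggingface/transformers';

const synthesizer = await pipeline('text-to-speech', 'onnx-community/Supertonic-TTS-ONNX');

// Download the voice embeddings once...
const url = 'https://huggingface.co/onnx-community/Supertonic-TTS-ONNX/resolve/main/voices/F1.bin';
const speaker_embeddings = new Float32Array(await (await fetch(url)).arrayBuffer());

// ...and pass the same Float32Array to every call.
const first = await synthesizer('First sentence.', { speaker_embeddings });
const second = await synthesizer('Second sentence.', { speaker_embeddings });
```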
@@ -105,39 +162,36 @@ export class TextToAudioPipeline
 
         // @ts-expect-error TS2339
         const sampling_rate = this.model.config.sampling_rate;
-        return new RawAudio(waveform.data, sampling_rate);
+        return new RawAudio(
+            waveform.data,
+            sampling_rate,
+        )
     }
 
     async _call_text_to_spectrogram(text_inputs, { speaker_embeddings }) {
+
         // Load vocoder, if not provided
         if (!this.vocoder) {
             console.log('No vocoder specified, using default HifiGan vocoder.');
             this.vocoder = await AutoModel.from_pretrained(this.DEFAULT_VOCODER_ID, { dtype: 'fp32' });
         }
 
-        // Load speaker embeddings as Float32Array from path/URL
-        if (typeof speaker_embeddings === 'string' || speaker_embeddings instanceof URL) {
-            // Load from URL with fetch
-            speaker_embeddings = new Float32Array(await (await fetch(speaker_embeddings)).arrayBuffer());
-        }
-
-        if (speaker_embeddings instanceof Float32Array) {
-            speaker_embeddings = new Tensor('float32', speaker_embeddings, [1, speaker_embeddings.length]);
-        } else if (!(speaker_embeddings instanceof Tensor)) {
-            throw new Error('Speaker embeddings must be a `Tensor`, `Float32Array`, `string`, or `URL`.');
-        }
-
         // Run tokenization
         const { input_ids } = this.tokenizer(text_inputs, {
             padding: true,
             truncation: true,
         });
 
-        // NOTE: At this point, we are guaranteed that `speaker_embeddings` is a `Tensor`
-        // @ts-ignore
+        speaker_embeddings = await this._prepare_speaker_embeddings(speaker_embeddings);
+        speaker_embeddings = speaker_embeddings.view(1, -1);
+
+        // @ts-expect-error TS2339
         const { waveform } = await this.model.generate_speech(input_ids, speaker_embeddings, { vocoder: this.vocoder });
 
         const sampling_rate = this.processor.feature_extractor.config.sampling_rate;
-        return new RawAudio(waveform.data, sampling_rate);
+        return new RawAudio(
+            waveform.data,
+            sampling_rate,
+        )
     }
 }
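
The spectrogram path is unchanged in behaviour: it now simply routes through the shared `_prepare_speaker_embeddings` helper and reshapes the result with `.view(1, -1)` before calling `generate_speech`. A sketch of that path, reusing the model ID and embeddings URL from the example removed above; this snippet is not part of the commit, and dtype/device options are omitted for brevity.

```javascript
import { pipeline } from '@huggingface/transformers';

// SpeechT5 goes through _call_text_to_spectrogram: a processor plus the default HifiGan vocoder.
const synthesizer = await pipeline('text-to-speech', 'Xenova/speecht5_tts');
const speaker_embeddings = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin';

const output = await synthesizer('Hello, my dog is cute', { speaker_embeddings });
await output.save('speecht5.wav'); // same RawAudio helper as in the Supertonic example
```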
