Commit 9e3c586

Add support for Supertonic TTS (#1459)

* Add support for FixedLength pre-tokenizer
* Create randn tensor function
* Add support for Supertonic TTS models
* Update TTS JSDoc
* Add supertonic speed parameter
* Update list of supported models
1 parent 8337acc commit 9e3c586

File tree

6 files changed, +199 -40 lines changed


README.md

Lines changed: 2 additions & 1 deletion
```diff
@@ -432,7 +432,8 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://huggingface.co/papers/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
 1. **[StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm)** (from Stability AI) released with the paper [StableLM 3B 4E1T (Technical Report)](https://stability.wandb.io/stability-llm/stable-lm/reports/StableLM-3B-4E1T--VmlldzoyMjU4?accessToken=u3zujipenkx5g7rtcj9qojjgxpconyjktjkli2po09nffrffdhhchq045vp0wyfo) by Jonathan Tow, Marco Bellagente, Dakota Mahan, Carlos Riquelme Ruiz, Duy Phung, Maksym Zhuravinskyi, Nathan Cooper, Nikhil Pinnaparaju, Reshinth Adithyan, and James Baicoianu.
 1. **[Starcoder2](https://huggingface.co/docs/transformers/main/model_doc/starcoder2)** (from BigCode team) released with the paper [StarCoder 2 and The Stack v2: The Next Generation](https://huggingface.co/papers/2402.19173) by Anton Lozhkov, Raymond Li, Loubna Ben Allal, Federico Cassano, Joel Lamy-Poirier, Nouamane Tazi, Ao Tang, Dmytro Pykhtar, Jiawei Liu, Yuxiang Wei, Tianyang Liu, Max Tian, Denis Kocetkov, Arthur Zucker, Younes Belkada, Zijian Wang, Qian Liu, Dmitry Abulkhanov, Indraneil Paul, Zhuang Li, Wen-Ding Li, Megan Risdal, Jia Li, Jian Zhu, Terry Yue Zhuo, Evgenii Zheltonozhskii, Nii Osae Osae Dade, Wenhao Yu, Lucas Krauß, Naman Jain, Yixuan Su, Xuanli He, Manan Dey, Edoardo Abati, Yekun Chai, Niklas Muennighoff, Xiangru Tang, Muhtasham Oblokulov, Christopher Akiki, Marc Marone, Chenghao Mou, Mayank Mishra, Alex Gu, Binyuan Hui, Tri Dao, Armel Zebaze, Olivier Dehaene, Nicolas Patry, Canwen Xu, Julian McAuley, Han Hu, Torsten Scholak, Sebastien Paquet, Jennifer Robinson, Carolyn Jane Anderson, Nicolas Chapados, Mostofa Patwary, Nima Tajbakhsh, Yacine Jernite, Carlos Muñoz Ferrandis, Lingming Zhang, Sean Hughes, Thomas Wolf, Arjun Guha, Leandro von Werra, and Harm de Vries.
-1. StyleTTS 2 (from Columbia University) released with the paper [StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models](https://huggingface.co/papers/2306.07691) by Yinghao Aaron Li, Cong Han, Vinay S. Raghavan, Gavin Mischler, Nima Mesgarani.
+1. **StyleTTS 2** (from Columbia University) released with the paper [StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models](https://huggingface.co/papers/2306.07691) by Yinghao Aaron Li, Cong Han, Vinay S. Raghavan, Gavin Mischler, Nima Mesgarani.
+1. **Supertonic** (from Supertone) released with the paper [Training Flow Matching Models with Reliable Labels via Self-Purification](https://huggingface.co/papers/2509.19091) by Hyeongju Kim, Yechan Yu, June Young Yi, Juheon Lee.
 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://huggingface.co/papers/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.
 1. **[Swin2SR](https://huggingface.co/docs/transformers/model_doc/swin2sr)** (from University of Würzburg) released with the paper [Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration](https://huggingface.co/papers/2209.11345) by Marcos V. Conde, Ui-Jin Choi, Maxime Burchi, Radu Timofte.
 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://huggingface.co/papers/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
```

docs/snippets/6_supported-models.snippet

Lines changed: 2 additions & 1 deletion
```diff
@@ -146,7 +146,8 @@
 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://huggingface.co/papers/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
 1. **[StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm)** (from Stability AI) released with the paper [StableLM 3B 4E1T (Technical Report)](https://stability.wandb.io/stability-llm/stable-lm/reports/StableLM-3B-4E1T--VmlldzoyMjU4?accessToken=u3zujipenkx5g7rtcj9qojjgxpconyjktjkli2po09nffrffdhhchq045vp0wyfo) by Jonathan Tow, Marco Bellagente, Dakota Mahan, Carlos Riquelme Ruiz, Duy Phung, Maksym Zhuravinskyi, Nathan Cooper, Nikhil Pinnaparaju, Reshinth Adithyan, and James Baicoianu.
 1. **[Starcoder2](https://huggingface.co/docs/transformers/main/model_doc/starcoder2)** (from BigCode team) released with the paper [StarCoder 2 and The Stack v2: The Next Generation](https://huggingface.co/papers/2402.19173) by Anton Lozhkov, Raymond Li, Loubna Ben Allal, Federico Cassano, Joel Lamy-Poirier, Nouamane Tazi, Ao Tang, Dmytro Pykhtar, Jiawei Liu, Yuxiang Wei, Tianyang Liu, Max Tian, Denis Kocetkov, Arthur Zucker, Younes Belkada, Zijian Wang, Qian Liu, Dmitry Abulkhanov, Indraneil Paul, Zhuang Li, Wen-Ding Li, Megan Risdal, Jia Li, Jian Zhu, Terry Yue Zhuo, Evgenii Zheltonozhskii, Nii Osae Osae Dade, Wenhao Yu, Lucas Krauß, Naman Jain, Yixuan Su, Xuanli He, Manan Dey, Edoardo Abati, Yekun Chai, Niklas Muennighoff, Xiangru Tang, Muhtasham Oblokulov, Christopher Akiki, Marc Marone, Chenghao Mou, Mayank Mishra, Alex Gu, Binyuan Hui, Tri Dao, Armel Zebaze, Olivier Dehaene, Nicolas Patry, Canwen Xu, Julian McAuley, Han Hu, Torsten Scholak, Sebastien Paquet, Jennifer Robinson, Carolyn Jane Anderson, Nicolas Chapados, Mostofa Patwary, Nima Tajbakhsh, Yacine Jernite, Carlos Muñoz Ferrandis, Lingming Zhang, Sean Hughes, Thomas Wolf, Arjun Guha, Leandro von Werra, and Harm de Vries.
-1. StyleTTS 2 (from Columbia University) released with the paper [StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models](https://huggingface.co/papers/2306.07691) by Yinghao Aaron Li, Cong Han, Vinay S. Raghavan, Gavin Mischler, Nima Mesgarani.
+1. **StyleTTS 2** (from Columbia University) released with the paper [StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models](https://huggingface.co/papers/2306.07691) by Yinghao Aaron Li, Cong Han, Vinay S. Raghavan, Gavin Mischler, Nima Mesgarani.
+1. **Supertonic** (from Supertone) released with the paper [Training Flow Matching Models with Reliable Labels via Self-Purification](https://huggingface.co/papers/2509.19091) by Hyeongju Kim, Yechan Yu, June Young Yi, Juheon Lee.
 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://huggingface.co/papers/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.
 1. **[Swin2SR](https://huggingface.co/docs/transformers/model_doc/swin2sr)** (from University of Würzburg) released with the paper [Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration](https://huggingface.co/papers/2209.11345) by Marcos V. Conde, Ui-Jin Choi, Maxime Burchi, Radu Timofte.
 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://huggingface.co/papers/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
```

src/models.js

Lines changed: 67 additions & 0 deletions
```diff
@@ -109,6 +109,7 @@ import {
     std_mean,
     Tensor,
     DataTypeMap,
+    randn,
 } from './utils/tensor.js';
 import { RawImage } from './utils/image.js';
 
```
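For context, the commit message mentions a new `randn` tensor helper, but its implementation in `src/utils/tensor.js` is not part of the hunks shown here. A minimal sketch of such a function, assuming Box-Muller sampling into a flat `Float32Array` and the same `Tensor('float32', data, dims)` constructor used elsewhere in this commit, could look like:

```javascript
// Sketch only: the actual randn in src/utils/tensor.js may differ.
// Fills a tensor of the given shape with standard-normal samples (Box-Muller transform).
function randn(dims) {
    const size = dims.reduce((a, b) => a * b, 1);
    const data = new Float32Array(size);
    for (let i = 0; i < size; i += 2) {
        const u1 = Math.random() || Number.MIN_VALUE; // guard against log(0)
        const u2 = Math.random();
        const r = Math.sqrt(-2 * Math.log(u1));
        data[i] = r * Math.cos(2 * Math.PI * u2);
        if (i + 1 < size) data[i + 1] = r * Math.sin(2 * Math.PI * u2);
    }
    return new Tensor('float32', data, dims);
}
```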
```diff
@@ -136,6 +137,7 @@ const MODEL_TYPES = {
     AudioTextToText: 10,
     AutoEncoder: 11,
     ImageAudioTextToText: 12,
+    Supertonic: 13,
 }
 //////////////////////////////////////////////////
```

```diff
@@ -1262,6 +1264,14 @@ export class PreTrainedModel extends Callable {
                     decoder_model: 'decoder_model',
                 }, options),
             ]);
+        } else if (modelType === MODEL_TYPES.Supertonic) {
+            info = await Promise.all([
+                constructSessions(pretrained_model_name_or_path, {
+                    text_encoder: 'text_encoder',
+                    latent_denoiser: 'latent_denoiser',
+                    voice_decoder: 'voice_decoder',
+                }, options),
+            ]);
         } else { // should be MODEL_TYPES.EncoderOnly
             if (modelType !== MODEL_TYPES.EncoderOnly) {
                 const type = modelName ?? config?.model_type;
```
```diff
@@ -6861,6 +6871,61 @@ export class SpeechT5HifiGan extends PreTrainedModel {
 }
 //////////////////////////////////////////////////
 
+export class SupertonicPreTrainedModel extends PreTrainedModel { }
+export class SupertonicForConditionalGeneration extends SupertonicPreTrainedModel {
+
+    async generate_speech({
+        // Required inputs
+        input_ids,
+        attention_mask,
+        style,
+
+        // Optional inputs
+        num_inference_steps = 5,
+        speed = 1.05,
+    }) {
+        // @ts-expect-error TS2339
+        const { sampling_rate, chunk_compress_factor, base_chunk_size, latent_dim } = this.config;
+
+        // 1. Text Encoder
+        const { last_hidden_state, durations } = await sessionRun(this.sessions['text_encoder'], {
+            input_ids, attention_mask, style,
+        });
+        durations.div_(speed); // Apply speed factor to duration
+
+        // 2. Latent Denoiser
+        const wav_len_max = durations.max().item() * sampling_rate;
+        const chunk_size = base_chunk_size * chunk_compress_factor;
+        const latent_len = Math.floor((wav_len_max + chunk_size - 1) / chunk_size);
+        const batch_size = input_ids.dims[0];
+        const latent_mask = ones([batch_size, latent_len]);
+        const num_steps = full([batch_size], num_inference_steps);
+
+        let noisy_latents = randn([batch_size, latent_dim * chunk_compress_factor, latent_len]);
+        for (let step = 0; step < num_inference_steps; ++step) {
+            const timestep = full([batch_size], step);
+            ({ denoised_latents: noisy_latents } = await sessionRun(this.sessions['latent_denoiser'], {
+                style,
+                noisy_latents,
+                latent_mask,
+                encoder_outputs: last_hidden_state,
+                attention_mask,
+                timestep,
+                num_inference_steps: num_steps,
+            }));
+        }
+
+        // 3. Voice Decoder
+        const { waveform } = await sessionRun(this.sessions['voice_decoder'], {
+            latents: noisy_latents,
+        });
+        return {
+            waveform,
+            durations,
+        }
+    }
+}
+
 
 //////////////////////////////////////////////////
 // TrOCR models
```
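In plain terms, `generate_speech` chains the three sessions: the text encoder predicts hidden states and per-utterance durations (shortened by dividing by `speed`), the latent denoiser refines a Gaussian-noise latent over `num_inference_steps` flow-matching steps, and the voice decoder converts the final latents into a waveform. The latent length is a ceiling division of the longest waveform length by the chunk size; a worked example of that arithmetic, using hypothetical config values chosen only for illustration:

```javascript
// Hypothetical config values; the real ones come from the model's config.json.
const sampling_rate = 44100;      // output sample rate in Hz
const base_chunk_size = 512;      // assumed samples per latent frame before compression
const chunk_compress_factor = 4;  // assumed compression factor
const duration_max = 2.0;         // longest predicted duration in seconds

const wav_len_max = duration_max * sampling_rate;                           // 88200 samples
const chunk_size = base_chunk_size * chunk_compress_factor;                 // 2048 samples per frame
const latent_len = Math.floor((wav_len_max + chunk_size - 1) / chunk_size); // ceil(88200 / 2048) = 44
```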
```diff
@@ -8000,6 +8065,7 @@ const MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = new Map([
 const MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = new Map([
     ['vits', ['VitsModel', VitsModel]],
     ['musicgen', ['MusicgenForConditionalGeneration', MusicgenForConditionalGeneration]],
+    ['supertonic', ['SupertonicForConditionalGeneration', SupertonicForConditionalGeneration]],
 ]);
 
 const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([
```
```diff
@@ -8386,6 +8452,7 @@ const CUSTOM_MAPPING = [
     ['SnacDecoderModel', SnacDecoderModel, MODEL_TYPES.EncoderOnly],
 
     ['Gemma3nForConditionalGeneration', Gemma3nForConditionalGeneration, MODEL_TYPES.ImageAudioTextToText],
+    ['SupertonicForConditionalGeneration', SupertonicForConditionalGeneration, MODEL_TYPES.Supertonic],
 ]
 for (const [name, model, type] of CUSTOM_MAPPING) {
     MODEL_TYPE_MAPPING.set(name, type);
```
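With `'supertonic'` registered in `MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES` and the class routed through the new `MODEL_TYPES.Supertonic` loading path, the model can also be driven without the pipeline. The following is a rough sketch, not tested usage: it assumes `AutoModel` resolves the architecture from these mappings, and that the voice file and `style_dim` reshaping mirror the pipeline code in this commit.

```javascript
import { AutoTokenizer, AutoModel, Tensor, RawAudio } from '@huggingface/transformers';

const model_id = 'onnx-community/Supertonic-TTS-ONNX';
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
const model = await AutoModel.from_pretrained(model_id);

// Load a reference voice and reshape it to [1, -1, style_dim], as the pipeline does
const res = await fetch(`https://huggingface.co/${model_id}/resolve/main/voices/F1.bin`);
const style_data = new Float32Array(await res.arrayBuffer());
const style = new Tensor('float32', style_data, [style_data.length])
    .view(1, -1, model.config.style_dim);

const inputs = tokenizer('Hello there, how are you doing?', { padding: true, truncation: true });
const { waveform } = await model.generate_speech({
    ...inputs,
    style,
    num_inference_steps: 5, // the commit's default
    speed: 1.05,            // the commit's default
});
const audio = new RawAudio(waveform.data, model.config.sampling_rate);
```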

src/pipelines.js

Lines changed: 73 additions & 37 deletions
```diff
@@ -2792,6 +2792,9 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options:
  *
  * @typedef {Object} TextToAudioPipelineOptions Parameters specific to text-to-audio pipelines.
  * @property {Tensor|Float32Array|string|URL} [speaker_embeddings=null] The speaker embeddings (if the model requires it).
+ * @property {number} [num_inference_steps] The number of denoising steps (if the model supports it).
+ * More denoising steps usually lead to higher quality audio but slower inference.
+ * @property {number} [speed] The speed of the generated audio (if the model supports it).
  *
  * @callback TextToAudioPipelineCallback Generates speech/audio from the inputs.
  * @param {string|string[]} texts The text(s) to generate.
```
````diff
@@ -2805,31 +2808,24 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options:
  * Text-to-audio generation pipeline using any `AutoModelForTextToWaveform` or `AutoModelForTextToSpectrogram`.
  * This pipeline generates an audio file from an input text and optional other conditional inputs.
  *
- * **Example:** Generate audio from text with `Xenova/speecht5_tts`.
+ * **Example:** Generate audio from text with `onnx-community/Supertonic-TTS-ONNX`.
  * ```javascript
- * const synthesizer = await pipeline('text-to-speech', 'Xenova/speecht5_tts', { quantized: false });
- * const speaker_embeddings = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin';
- * const out = await synthesizer('Hello, my dog is cute', { speaker_embeddings });
+ * const synthesizer = await pipeline('text-to-speech', 'onnx-community/Supertonic-TTS-ONNX');
+ * const speaker_embeddings = 'https://huggingface.co/onnx-community/Supertonic-TTS-ONNX/resolve/main/voices/F1.bin';
+ * const output = await synthesizer('Hello there, how are you doing?', { speaker_embeddings });
  * // RawAudio {
- * //   audio: Float32Array(26112) [-0.00005657337896991521, 0.00020583874720614403, ...],
- * //   sampling_rate: 16000
+ * //   audio: Float32Array(95232) [-0.000482565927086398, -0.0004853440332226455, ...],
+ * //   sampling_rate: 44100
  * // }
- * ```
- *
- * You can then save the audio to a .wav file with the `wavefile` package:
- * ```javascript
- * import wavefile from 'wavefile';
- * import fs from 'fs';
- *
- * const wav = new wavefile.WaveFile();
- * wav.fromScratch(1, out.sampling_rate, '32f', out.audio);
- * fs.writeFileSync('out.wav', wav.toBuffer());
+ *
+ * // Optional: Save the audio to a .wav file or Blob
+ * await output.save('output.wav'); // You can also use `output.toBlob()` to access the audio as a Blob
  * ```
  *
  * **Example:** Multilingual speech generation with `Xenova/mms-tts-fra`. See [here](https://huggingface.co/models?pipeline_tag=text-to-speech&other=vits&sort=trending) for the full list of available languages (1107).
  * ```javascript
  * const synthesizer = await pipeline('text-to-speech', 'Xenova/mms-tts-fra');
- * const out = await synthesizer('Bonjour');
+ * const output = await synthesizer('Bonjour');
  * // RawAudio {
  * //   audio: Float32Array(23808) [-0.00037693005288019776, 0.0003325853613205254, ...],
  * //   sampling_rate: 16000
````
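The `num_inference_steps` and `speed` options documented in the typedef above are not exercised in these JSDoc examples. A short sketch of passing them through the pipeline (the option names come from this commit; the values are illustrative):

```javascript
import { pipeline } from '@huggingface/transformers';

const synthesizer = await pipeline('text-to-speech', 'onnx-community/Supertonic-TTS-ONNX');
const speaker_embeddings = 'https://huggingface.co/onnx-community/Supertonic-TTS-ONNX/resolve/main/voices/F1.bin';

const output = await synthesizer('Hello there, how are you doing?', {
    speaker_embeddings,
    num_inference_steps: 10, // more denoising steps: usually higher quality, slower inference
    speed: 1.25,             // values above 1 speak faster; the commit's default is 1.05
});
await output.save('output.wav');
```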
```diff
@@ -2850,20 +2846,76 @@ export class TextToAudioPipeline extends (/** @type {new (options: TextToAudioPi
         this.vocoder = options.vocoder ?? null;
     }
 
+    async _prepare_speaker_embeddings(speaker_embeddings) {
+        // Load speaker embeddings as Float32Array from path/URL
+        if (typeof speaker_embeddings === 'string' || speaker_embeddings instanceof URL) {
+            // Load from URL with fetch
+            speaker_embeddings = new Float32Array(
+                await (await fetch(speaker_embeddings)).arrayBuffer()
+            );
+        }
+
+        if (speaker_embeddings instanceof Float32Array) {
+            speaker_embeddings = new Tensor(
+                'float32',
+                speaker_embeddings,
+                [speaker_embeddings.length]
+            )
+        } else if (!(speaker_embeddings instanceof Tensor)) {
+            throw new Error("Speaker embeddings must be a `Tensor`, `Float32Array`, `string`, or `URL`.")
+        }
+
+        return speaker_embeddings;
+    }
 
     /** @type {TextToAudioPipelineCallback} */
     async _call(text_inputs, {
         speaker_embeddings = null,
+        num_inference_steps,
+        speed,
     } = {}) {
 
         // If this.processor is not set, we are using a `AutoModelForTextToWaveform` model
         if (this.processor) {
             return this._call_text_to_spectrogram(text_inputs, { speaker_embeddings });
+        } else if (
+            this.model.config.model_type === "supertonic"
+        ) {
+            return this._call_supertonic(text_inputs, { speaker_embeddings, num_inference_steps, speed });
         } else {
             return this._call_text_to_waveform(text_inputs);
         }
     }
 
+    async _call_supertonic(text_inputs, { speaker_embeddings, num_inference_steps, speed }) {
+        if (!speaker_embeddings) {
+            throw new Error("Speaker embeddings must be provided for Supertonic models.");
+        }
+        speaker_embeddings = await this._prepare_speaker_embeddings(speaker_embeddings);
+
+        // @ts-expect-error TS2339
+        const { sampling_rate, style_dim } = this.model.config;
+
+        speaker_embeddings = (/** @type {Tensor} */ (speaker_embeddings)).view(1, -1, style_dim);
+        const inputs = this.tokenizer(text_inputs, {
+            padding: true,
+            truncation: true,
+        });
+
+        // @ts-expect-error TS2339
+        const { waveform } = await this.model.generate_speech({
+            ...inputs,
+            style: speaker_embeddings,
+            num_inference_steps,
+            speed,
+        });
+
+        return new RawAudio(
+            waveform.data,
+            sampling_rate,
+        )
+    }
+
     async _call_text_to_waveform(text_inputs) {
 
         // Run tokenization
```
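The extracted `_prepare_speaker_embeddings` helper normalizes all three accepted input forms before the Supertonic path or the SpeechT5-style spectrogram path reshapes the result. As a quick sketch, reusing a `synthesizer` like the one above and assuming `Tensor` is imported from `@huggingface/transformers`:

```javascript
const url = 'https://huggingface.co/onnx-community/Supertonic-TTS-ONNX/resolve/main/voices/F1.bin';

// 1. A string or URL: fetched and decoded into a Float32Array internally
let output = await synthesizer('Hello!', { speaker_embeddings: url });

// 2. A Float32Array: wrapped in a 1-D float32 Tensor internally
const data = new Float32Array(await (await fetch(url)).arrayBuffer());
output = await synthesizer('Hello!', { speaker_embeddings: data });

// 3. A pre-built Tensor: passed through unchanged, then reshaped per model type
const tensor = new Tensor('float32', data, [data.length]);
output = await synthesizer('Hello!', { speaker_embeddings: tensor });
```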
```diff
@@ -2891,32 +2943,16 @@ export class TextToAudioPipeline extends (/** @type {new (options: TextToAudioPi
             this.vocoder = await AutoModel.from_pretrained(this.DEFAULT_VOCODER_ID, { dtype: 'fp32' });
         }
 
-        // Load speaker embeddings as Float32Array from path/URL
-        if (typeof speaker_embeddings === 'string' || speaker_embeddings instanceof URL) {
-            // Load from URL with fetch
-            speaker_embeddings = new Float32Array(
-                await (await fetch(speaker_embeddings)).arrayBuffer()
-            );
-        }
-
-        if (speaker_embeddings instanceof Float32Array) {
-            speaker_embeddings = new Tensor(
-                'float32',
-                speaker_embeddings,
-                [1, speaker_embeddings.length]
-            )
-        } else if (!(speaker_embeddings instanceof Tensor)) {
-            throw new Error("Speaker embeddings must be a `Tensor`, `Float32Array`, `string`, or `URL`.")
-        }
-
         // Run tokenization
         const { input_ids } = this.tokenizer(text_inputs, {
             padding: true,
             truncation: true,
         });
 
-        // NOTE: At this point, we are guaranteed that `speaker_embeddings` is a `Tensor`
-        // @ts-ignore
+        speaker_embeddings = await this._prepare_speaker_embeddings(speaker_embeddings);
+        speaker_embeddings = speaker_embeddings.view(1, -1);
+
+        // @ts-expect-error TS2339
         const { waveform } = await this.model.generate_speech(input_ids, speaker_embeddings, { vocoder: this.vocoder });
 
         const sampling_rate = this.processor.feature_extractor.config.sampling_rate;
```
