
Commit 56a7608

Add AudioLDM2 TTS (#5381)

Authored by tuanh123789 and DN6

* add audioldm2 tts
* change gpt2 max new tokens
* remove unnecessary pipeline and class
* add TTS to AudioLDM2Pipeline
* add TTS docs
* delete unnecessary file
* remove unnecessary import
* add audioldm2 slow testcase
* fix code quality
* remove AudioLDMLearnablePositionalEmbedding
* add variable check vits encoder
* add use_learned_position_embedding

Co-authored-by: Dhruv Nair <[email protected]>

Parent: 6133d98

File tree: 4 files changed, +152 −13 lines
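Since this commit makes `text_encoder_2` and `tokenizer_2` accept either Flan-T5 or VITS components (see the pipeline changes below), the loaded components tell you whether a checkpoint runs in text-to-audio or text-to-speech mode. A minimal sketch of that check, assuming the `anhnct/audioldm2_gigaspeech` checkpoint added in this PR; the snippet is illustrative and not part of the commit:

```py
import torch
from transformers import VitsModel
from diffusers import AudioLDM2Pipeline

pipe = AudioLDM2Pipeline.from_pretrained("anhnct/audioldm2_gigaspeech", torch_dtype=torch.float16)

# TTS checkpoints ship a VITS text encoder as the second encoder; text-to-audio
# checkpoints keep the Flan-T5 encoder here. encode_prompt branches on exactly this check.
print(isinstance(pipe.text_encoder_2, VitsModel))  # True for the gigaspeech/ljspeech checkpoints
```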

docs/source/en/api/pipelines/audioldm2.md

Lines changed: 5 additions & 2 deletions

@@ -20,7 +20,8 @@ The abstract of the paper is the following:
 
 *Although audio generation shares commonalities across different types of audio, such as speech, music, and sound effects, designing models for each type requires careful consideration of specific objectives and biases that can significantly differ from those of other types. To bring us closer to a unified perspective of audio generation, this paper proposes a framework that utilizes the same learning method for speech, music, and sound effect generation. Our framework introduces a general representation of audio, called "language of audio" (LOA). Any audio can be translated into LOA based on AudioMAE, a self-supervised pre-trained representation learning model. In the generation process, we translate any modalities into LOA by using a GPT-2 model, and we perform self-supervised audio generation learning with a latent diffusion model conditioned on LOA. The proposed framework naturally brings advantages such as in-context learning abilities and reusable self-supervised pretrained AudioMAE and latent diffusion models. Experiments on the major benchmarks of text-to-audio, text-to-music, and text-to-speech demonstrate state-of-the-art or competitive performance against previous approaches. Our code, pretrained model, and demo are available at [this https URL](https://audioldm.github.io/audioldm2).*
 
-This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi). The original codebase can be found at [haoheliu/audioldm2](https://github.com/haoheliu/audioldm2).
+This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi) and [Nguyễn Công Tú Anh](https://github.com/tuanh123789). The original codebase can be
+found at [haoheliu/audioldm2](https://github.com/haoheliu/audioldm2).
 
 ## Tips
 
@@ -36,6 +37,8 @@ See table below for details on the three checkpoints:
 | [audioldm2](https://huggingface.co/cvssp/audioldm2) | Text-to-audio | 350M | 1.1B | 1150k |
 | [audioldm2-large](https://huggingface.co/cvssp/audioldm2-large) | Text-to-audio | 750M | 1.5B | 1150k |
 | [audioldm2-music](https://huggingface.co/cvssp/audioldm2-music) | Text-to-music | 350M | 1.1B | 665k |
+| [audioldm2-gigaspeech](https://huggingface.co/anhnct/audioldm2_gigaspeech) | Text-to-speech | 350M | 1.1B | 10k |
+| [audioldm2-ljspeech](https://huggingface.co/anhnct/audioldm2_ljspeech) | Text-to-speech | 350M | 1.1B | |
 
 ### Constructing a prompt
 
@@ -53,7 +56,7 @@ See table below for details on the three checkpoints:
 * The quality of the generated waveforms can vary significantly based on the seed. Try generating with different seeds until you find a satisfactory generation.
 * Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.
 
-The following example demonstrates how to construct good music generation using the aforementioned tips: [example](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2Pipeline.__call__.example).
+The following example demonstrates how to construct good music and speech generation using the aforementioned tips: [example](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2Pipeline.__call__.example).
 
 <Tip>
 
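The two tips above (seed sensitivity and `num_waveforms_per_prompt` ranking) map directly onto pipeline arguments. A minimal sketch, assuming the `cvssp/audioldm2` checkpoint and a CUDA device; the prompt text is only illustrative:

```py
import torch
from diffusers import AudioLDM2Pipeline

pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2", torch_dtype=torch.float16).to("cuda")
prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"

# quality varies noticeably with the seed, so sweep a few generators
for seed in (0, 1, 2):
    generator = torch.Generator("cuda").manual_seed(seed)
    # num_waveforms_per_prompt > 1 enables automatic CLAP scoring; the returned
    # audios are ranked best-to-worst, so index 0 is the top candidate
    audio = pipe(
        prompt,
        num_inference_steps=200,
        audio_length_in_s=10.0,
        num_waveforms_per_prompt=3,
        generator=generator,
    ).audios[0]
```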

src/diffusers/pipelines/audioldm2/modeling_audioldm2.py

Lines changed: 20 additions & 1 deletion

@@ -95,7 +95,14 @@ class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin):
     """
 
     @register_to_config
-    def __init__(self, text_encoder_dim, text_encoder_1_dim, langauge_model_dim):
+    def __init__(
+        self,
+        text_encoder_dim,
+        text_encoder_1_dim,
+        langauge_model_dim,
+        use_learned_position_embedding=None,
+        max_seq_length=None,
+    ):
         super().__init__()
         # additional projection layers for each text encoder
         self.projection = nn.Linear(text_encoder_dim, langauge_model_dim)
@@ -108,6 +115,14 @@ def __init__(self, text_encoder_dim, text_encoder_1_dim, langauge_model_dim):
         self.sos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))
         self.eos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))
 
+        self.use_learned_position_embedding = use_learned_position_embedding
+
+        # learnable positional embedding for vits encoder
+        if self.use_learned_position_embedding is not None:
+            self.learnable_positional_embedding = torch.nn.Parameter(
+                torch.zeros((1, text_encoder_1_dim, max_seq_length))
+            )
+
     def forward(
         self,
         hidden_states: Optional[torch.FloatTensor] = None,
@@ -120,6 +135,10 @@ def forward(
             hidden_states, attention_mask, sos_token=self.sos_embed, eos_token=self.eos_embed
         )
 
+        # Add positional embedding for Vits hidden state
+        if self.use_learned_position_embedding is not None:
+            hidden_states_1 = (hidden_states_1.permute(0, 2, 1) + self.learnable_positional_embedding).permute(0, 2, 1)
+
        hidden_states_1 = self.projection_1(hidden_states_1)
        hidden_states_1, attention_mask_1 = add_special_tokens(
            hidden_states_1, attention_mask_1, sos_token=self.sos_embed_1, eos_token=self.eos_embed_1
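The permutes in `forward` exist only to line up the `(1, text_encoder_1_dim, max_seq_length)` parameter with `(batch, seq, dim)` hidden-states. A standalone shape check; the sizes 192 and 512 are illustrative assumptions, not values taken from the commit:

```py
import torch

batch, seq_len, dim = 2, 512, 192                    # VITS hidden-states: (batch, seq, dim)
hidden_states_1 = torch.randn(batch, seq_len, dim)

# stored as (1, dim, max_seq_length); the broadcast assumes the padded sequence
# length equals max_seq_length (the pipeline pads VITS inputs to model_max_length)
learnable_positional_embedding = torch.zeros(1, dim, seq_len)

# permute to (batch, dim, seq), add the embedding, permute back
out = (hidden_states_1.permute(0, 2, 1) + learnable_positional_embedding).permute(0, 2, 1)
assert out.shape == (batch, seq_len, dim)
```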

src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py

Lines changed: 94 additions & 10 deletions

@@ -27,6 +27,8 @@
     T5EncoderModel,
     T5Tokenizer,
     T5TokenizerFast,
+    VitsModel,
+    VitsTokenizer,
 )
 
 from ...models import AutoencoderKL
@@ -79,6 +81,37 @@
         >>> # save the best audio sample (index 0) as a .wav file
         >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
         ```
+        ```
+        # Using AudioLDM2 for text-to-speech
+        >>> import scipy
+        >>> import torch
+        >>> from diffusers import AudioLDM2Pipeline
+
+        >>> repo_id = "anhnct/audioldm2_gigaspeech"
+        >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
+        >>> pipe = pipe.to("cuda")
+
+        >>> # define the prompts
+        >>> prompt = "A female reporter is speaking"
+        >>> transcript = "wish you have a good day"
+
+        >>> # set the seed for generator
+        >>> generator = torch.Generator("cuda").manual_seed(0)
+
+        >>> # run the generation
+        >>> audio = pipe(
+        ...     prompt,
+        ...     transcription=transcript,
+        ...     num_inference_steps=200,
+        ...     audio_length_in_s=10.0,
+        ...     num_waveforms_per_prompt=2,
+        ...     generator=generator,
+        ...     max_new_tokens=512,  # must set max_new_tokens equal to 512 for TTS
+        ... ).audios
+
+        >>> # save the best audio sample (index 0) as a .wav file
+        >>> scipy.io.wavfile.write("tts.wav", rate=16000, data=audio[0])
+        ```
 """
 
 
@@ -116,20 +149,23 @@ class AudioLDM2Pipeline(DiffusionPipeline):
             specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. The
             text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to
             rank generated waveforms against the text prompt by computing similarity scores.
-        text_encoder_2 ([`~transformers.T5EncoderModel`]):
+        text_encoder_2 ([`~transformers.T5EncoderModel`, `~transformers.VitsModel`]):
             Second frozen text-encoder. AudioLDM2 uses the encoder of
             [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
-            [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant.
+            [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant. For TTS, AudioLDM2 instead
+            uses the encoder of
+            [Vits](https://huggingface.co/docs/transformers/model_doc/vits#transformers.VitsModel).
         projection_model ([`AudioLDM2ProjectionModel`]):
             A trained model used to linearly project the hidden-states from the first and second text encoder models
             and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders are
-            concatenated to give the input to the language model.
+            concatenated to give the input to the language model. A learned position embedding is applied to the Vits
+            hidden-states.
         language_model ([`~transformers.GPT2Model`]):
             An auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected
             outputs from the two text encoders.
         tokenizer ([`~transformers.RobertaTokenizer`]):
             Tokenizer to tokenize text for the first frozen text-encoder.
-        tokenizer_2 ([`~transformers.T5Tokenizer`]):
+        tokenizer_2 ([`~transformers.T5Tokenizer`, `~transformers.VitsTokenizer`]):
             Tokenizer to tokenize text for the second frozen text-encoder.
         feature_extractor ([`~transformers.ClapFeatureExtractor`]):
             Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring.
@@ -146,11 +182,11 @@ def __init__(
         self,
         vae: AutoencoderKL,
         text_encoder: ClapModel,
-        text_encoder_2: T5EncoderModel,
+        text_encoder_2: Union[T5EncoderModel, VitsModel],
         projection_model: AudioLDM2ProjectionModel,
         language_model: GPT2Model,
         tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
-        tokenizer_2: Union[T5Tokenizer, T5TokenizerFast],
+        tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer],
         feature_extractor: ClapFeatureExtractor,
         unet: AudioLDM2UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
@@ -273,6 +309,7 @@ def encode_prompt(
         device,
         num_waveforms_per_prompt,
         do_classifier_free_guidance,
+        transcription=None,
         negative_prompt=None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -288,6 +325,8 @@
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
+            transcription (`str` or `List[str]`):
+                transcription of the text to be spoken, used for text-to-speech
             device (`torch.device`):
                 torch device
             num_waveforms_per_prompt (`int`):
@@ -368,16 +407,26 @@
 
         # Define tokenizers and text encoders
         tokenizers = [self.tokenizer, self.tokenizer_2]
-        text_encoders = [self.text_encoder, self.text_encoder_2]
+        is_vits_text_encoder = isinstance(self.text_encoder_2, VitsModel)
+
+        if is_vits_text_encoder:
+            text_encoders = [self.text_encoder, self.text_encoder_2.text_encoder]
+        else:
+            text_encoders = [self.text_encoder, self.text_encoder_2]
 
         if prompt_embeds is None:
             prompt_embeds_list = []
             attention_mask_list = []
 
             for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+                use_prompt = isinstance(
+                    tokenizer, (RobertaTokenizer, RobertaTokenizerFast, T5Tokenizer, T5TokenizerFast)
+                )
                 text_inputs = tokenizer(
-                    prompt,
-                    padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True,
+                    prompt if use_prompt else transcription,
+                    padding="max_length"
+                    if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer))
+                    else True,
                     max_length=tokenizer.model_max_length,
                     truncation=True,
                     return_tensors="pt",
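Spelled out: the CLAP (Roberta) and Flan-T5 tokenizers still encode `prompt`, while any other second tokenizer, i.e. the VITS one for TTS, encodes `transcription`. A hypothetical helper that mirrors the `use_prompt` branch above; it is not part of the commit:

```py
from transformers import RobertaTokenizer, RobertaTokenizerFast, T5Tokenizer, T5TokenizerFast

def pick_text_input(tokenizer, prompt, transcription):
    # style prompt for CLAP/T5 tokenizers, transcription for everything else (VITS)
    use_prompt = isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, T5Tokenizer, T5TokenizerFast))
    return prompt if use_prompt else transcription
```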
@@ -407,6 +456,18 @@
                     prompt_embeds = prompt_embeds[:, None, :]
                     # make sure that we attend to this single hidden-state
                     attention_mask = attention_mask.new_ones((batch_size, 1))
+                elif is_vits_text_encoder:
+                    # Add end_token_id and attention mask at the end of the phoneme sequence
+                    for text_input_id, text_attention_mask in zip(text_input_ids, attention_mask):
+                        for idx, phoneme_id in enumerate(text_input_id):
+                            if phoneme_id == 0:
+                                text_input_id[idx] = 182
+                                text_attention_mask[idx] = 1
+                                break
+                    prompt_embeds = text_encoder(
+                        text_input_ids, attention_mask=attention_mask, padding_mask=attention_mask.unsqueeze(-1)
+                    )
+                    prompt_embeds = prompt_embeds[0]
                 else:
                     prompt_embeds = text_encoder(
                         text_input_ids,
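The phoneme loop is easiest to see on a toy batch: the first padding id (0) in each row is overwritten with the end-of-sequence phoneme id (182, mirroring the value hard-coded in the commit) and its attention-mask slot is switched on. A minimal sketch with made-up ids:

```py
import torch

# two "phoneme" sequences padded with 0
text_input_ids = torch.tensor([[5, 9, 7, 0, 0],
                               [3, 2, 0, 0, 0]])
attention_mask = (text_input_ids != 0).long()

for text_input_id, text_attention_mask in zip(text_input_ids, attention_mask):
    for idx, phoneme_id in enumerate(text_input_id):
        if phoneme_id == 0:
            text_input_id[idx] = 182          # end-of-sequence phoneme id used by the pipeline
            text_attention_mask[idx] = 1      # attend to the newly inserted end token
            break

print(text_input_ids)
# tensor([[  5,   9,   7, 182,   0],
#         [  3,   2, 182,   0,   0]])
```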
@@ -485,7 +546,7 @@
                     uncond_tokens,
                     padding="max_length",
                     max_length=tokenizer.model_max_length
-                    if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
+                    if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer))
                     else max_length,
                     truncation=True,
                     return_tensors="pt",
@@ -503,6 +564,15 @@
                     negative_prompt_embeds = negative_prompt_embeds[:, None, :]
                     # make sure that we attend to this single hidden-state
                     negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1))
+                elif is_vits_text_encoder:
+                    negative_prompt_embeds = torch.zeros(
+                        batch_size,
+                        tokenizer.model_max_length,
+                        text_encoder.config.hidden_size,
+                    ).to(dtype=self.text_encoder_2.dtype, device=device)
+                    negative_attention_mask = torch.zeros(batch_size, tokenizer.model_max_length).to(
+                        dtype=self.text_encoder_2.dtype, device=device
+                    )
                 else:
                     negative_prompt_embeds = text_encoder(
                         uncond_input_ids,
@@ -623,6 +693,7 @@ def check_inputs(
         audio_length_in_s,
         vocoder_upsample_factor,
         callback_steps,
+        transcription=None,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
@@ -690,6 +761,14 @@ def check_inputs(
                 f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
             )
 
+        if transcription is None:
+            if self.text_encoder_2.config.model_type == "vits":
+                raise ValueError("Cannot forward without transcription. Please make sure to" " have transcription")
+        elif transcription is not None and (
+            not isinstance(transcription, str) and not isinstance(transcription, list)
+        ):
+            raise ValueError(f"`transcription` has to be of type `str` or `list` but is {type(transcription)}")
+
         if generated_prompt_embeds is not None and negative_generated_prompt_embeds is not None:
             if generated_prompt_embeds.shape != negative_generated_prompt_embeds.shape:
                 raise ValueError(
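The net effect for callers: a VITS-based (TTS) checkpoint refuses to run without a `transcription`. A hedged sketch of what that looks like from the user's side, assuming the `anhnct/audioldm2_gigaspeech` checkpoint; the prompt string is illustrative:

```py
import torch
from diffusers import AudioLDM2Pipeline

pipe = AudioLDM2Pipeline.from_pretrained("anhnct/audioldm2_gigaspeech", torch_dtype=torch.float16).to("cuda")

try:
    # no `transcription` passed although text_encoder_2 is a VITS model
    pipe("A female reporter is speaking", num_inference_steps=10)
except ValueError as err:
    print(err)  # "Cannot forward without transcription. Please make sure to have transcription"
```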
@@ -734,6 +813,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, devic
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
+        transcription: Union[str, List[str]] = None,
         audio_length_in_s: Optional[float] = None,
         num_inference_steps: int = 200,
         guidance_scale: float = 3.5,
@@ -761,6 +841,8 @@ def __call__(
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
+            transcription (`str` or `List[str]`, *optional*):
+                The transcript for text-to-speech.
             audio_length_in_s (`int`, *optional*, defaults to 10.24):
                 The length of the generated audio sample in seconds.
             num_inference_steps (`int`, *optional*, defaults to 200):
@@ -857,6 +939,7 @@ def __call__(
             audio_length_in_s,
             vocoder_upsample_factor,
             callback_steps,
+            transcription,
             negative_prompt,
             prompt_embeds,
             negative_prompt_embeds,
@@ -886,6 +969,7 @@ def __call__(
             device,
             num_waveforms_per_prompt,
             do_classifier_free_guidance,
+            transcription,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,

tests/pipelines/audioldm2/test_audioldm2.py

Lines changed: 33 additions & 0 deletions

@@ -516,6 +516,20 @@ def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0
         }
         return inputs
 
+    def get_inputs_tts(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=generator_device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A men saying",
+            "transcription": "hello my name is John",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 3,
+            "guidance_scale": 2.5,
+        }
+        return inputs
+
     def test_audioldm2(self):
         audioldm_pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
         audioldm_pipe = audioldm_pipe.to(torch_device)
@@ -572,3 +586,22 @@ def test_audioldm2_large(self):
         )
         max_diff = np.abs(expected_slice - audio_slice).max()
         assert max_diff < 1e-3
+
+    def test_audioldm2_tts(self):
+        audioldm_tts_pipe = AudioLDM2Pipeline.from_pretrained("anhnct/audioldm2_gigaspeech")
+        audioldm_tts_pipe = audioldm_tts_pipe.to(torch_device)
+        audioldm_tts_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs_tts(torch_device)
+        audio = audioldm_tts_pipe(**inputs).audios[0]
+
+        assert audio.ndim == 1
+        assert len(audio) == 81952
+
+        # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
+        audio_slice = audio[8825:8835]
+        expected_slice = np.array(
+            [-0.1829, -0.1461, 0.0759, -0.1493, -0.1396, 0.5783, 0.3001, -0.3038, -0.0639, -0.2244]
+        )
+        max_diff = np.abs(expected_slice - audio_slice).max()
+        assert max_diff < 1e-3
