
Commit 6185106

ylacombe, sanchit-gandhi, and Vaibhavs10 authored
V02 release (#94)
* bump version to v0.2
* adapt readme
* Update README.md
* update README
* add inference tips + streamer class
* update readme
* Update README.md
* Apply suggestions from code review

  Co-authored-by: Sanchit Gandhi <[email protected]>

* Update README
* Apply suggestions from code review

  Co-authored-by: Vaibhav Srivastav <[email protected]>

---------

Co-authored-by: Sanchit Gandhi <[email protected]>
Co-authored-by: Vaibhav Srivastav <[email protected]>
1 parent 1551b7c commit 6185106

File tree

5 files changed: +413 -19 lines changed


INFERENCE.md

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
# Inference tips

Parler-TTS benefits from a number of optimizations that can make the model up to 4x faster. Add to this the ability to stream audio as it's being generated, and you can achieve a time-to-first-audio of under 500 ms on a modern GPU.

## 📖 Quick Index
* [Efficient Attention implementations](#efficient-attention-implementations)
* [Compilation](#compilation)
* [Streaming](#streaming)
* [Batch generation](#batch-generation)

## Efficient Attention implementations

Parler-TTS supports [SDPA](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) and [Flash Attention 2](https://github.com/Dao-AILab/flash-attention).

SDPA is used by default and speeds up generation by up to 1.4x compared with eager attention.

To switch between attention implementations, simply pass the desired `attn_implementation` when loading the checkpoint:
```py
import torch
from parler_tts import ParlerTTSForConditionalGeneration

torch_device = "cuda:0"  # use "mps" for Mac
torch_dtype = torch.bfloat16
model_name = "parler-tts/parler-tts-mini-v1"

attn_implementation = "eager"  # "sdpa" or "flash_attention_2"

model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation=attn_implementation
).to(torch_device, dtype=torch_dtype)
```
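If you want to pick the fastest backend that is actually available at runtime, one simple pattern is to probe for Flash Attention 2 and fall back to SDPA. This is a minimal sketch, not part of the library: the `flash_attn` import check and the assumption that Flash Attention 2 needs a CUDA GPU (and half-precision weights) are choices you should verify for your setup.

```py
# sketch: choose an attention backend based on what is actually installed
# (assumption: flash-attn is only usable on CUDA with fp16/bf16 weights)
import importlib.util

import torch

if torch.cuda.is_available() and importlib.util.find_spec("flash_attn") is not None:
    attn_implementation = "flash_attention_2"
else:
    attn_implementation = "sdpa"  # default; works on CPU, CUDA and MPS

print(f"Using attention implementation: {attn_implementation}")
```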
## Compilation

[Compiling](https://pytorch.org/docs/stable/generated/torch.compile.html) the forward method of Parler-TTS can speed up generation by up to 4.5x.

As an indication, `mode="default"` brings a 1.4x speed-up compared to no compilation, while `mode="reduce-overhead"` brings much faster generation, at the cost of a longer compilation time and the need to generate twice before seeing the benefits of compilation.
```py
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

torch_device = "cuda:0"
torch_dtype = torch.bfloat16
model_name = "parler-tts/parler-tts-mini-v1"

# need to set padding max length
max_length = 50

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation="eager"
).to(torch_device, dtype=torch_dtype)

# compile the forward pass
compile_mode = "default"  # choose "reduce-overhead" for a 3 to 4x speed-up
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode=compile_mode)

# warmup
inputs = tokenizer("This is for compilation", return_tensors="pt", padding="max_length", max_length=max_length).to(torch_device)

model_kwargs = {**inputs, "prompt_input_ids": inputs.input_ids, "prompt_attention_mask": inputs.attention_mask}

n_steps = 1 if compile_mode == "default" else 2
for _ in range(n_steps):
    _ = model.generate(**model_kwargs)


# now you can benefit from compilation speed-ups
...
```
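As a rough sketch of what the `...` above could look like (this is not from the original guide), the snippet below keeps the inputs padded to the same `max_length` so that the compiled forward pass is re-used with static shapes; the description and prompt texts are placeholders.

```py
# hypothetical continuation of the example above: generate after warmup,
# re-using the same static input shapes as the warmup call
description = "A female speaker with a calm, clear voice."  # placeholder text
prompt = "Hey, how are you doing today?"                     # placeholder text

desc_inputs = tokenizer(description, return_tensors="pt", padding="max_length", max_length=max_length).to(torch_device)
prompt_inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", max_length=max_length).to(torch_device)

generation = model.generate(
    input_ids=desc_inputs.input_ids,
    attention_mask=desc_inputs.attention_mask,
    prompt_input_ids=prompt_inputs.input_ids,
    prompt_attention_mask=prompt_inputs.attention_mask,
)
# bfloat16 tensors can't be converted to NumPy directly, hence the cast to float32
audio_arr = generation.to(torch.float32).cpu().numpy().squeeze()
```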
## Streaming

### How Does It Work?

Parler-TTS is an auto-regressive transformer-based model, meaning it generates audio codes (tokens) in a causal fashion.

At each decoding step, the model generates a new set of audio codes, conditional on the text input and all previous audio codes. Given the frame rate of the [DAC model](https://huggingface.co/parler-tts/dac_44khZ_8kbps) used to decode the generated codes into an audio waveform, each set of generated audio codes corresponds to 0.011 seconds of audio. This means we require a total of 1720 decoding steps to generate 20 seconds of audio.

Rather than waiting for the entire audio sequence to be generated, which would require the full 1720 decoding steps, we can start playing the audio after a specified number of decoding steps have been reached, a technique known as [*streaming*](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming).
For example, after 86 steps we have the first second of audio ready, and so can play this without waiting for the remaining decoding steps to complete. As we continue to generate with the Parler-TTS model, we append new chunks of generated audio to our output waveform on-the-fly. After the full 1720 decoding steps, the generated audio is complete, and is composed of 20 chunks of audio, each corresponding to 86 tokens.
This method of playing incremental generations reduces the latency of the Parler-TTS model from the total time to generate 1720 tokens to the time taken to play the first chunk of audio (86 tokens). This can result in significant improvements to perceived latency, particularly when the chunk size is small. In practice, the chunk size should be tuned to your device: a smaller chunk size means the first chunk is ready faster, but it should not be so small that the model generates audio more slowly than the time it takes to play it back.
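To make the arithmetic above concrete, here is a small back-of-the-envelope helper (illustrative only, not part of the library; the decoding speed is an assumed figure for your hardware):

```py
# relate chunk size, frame rate and decoding speed; the numbers follow the example
# above, except steps_per_second, which is an ASSUMED decoding speed
frame_rate = 86           # decoding steps per second of audio (DAC frame rate)
play_steps_in_s = 1.0     # how much audio to buffer before playback starts
steps_per_second = 172    # assumption: decoding steps your GPU runs per second

chunk_steps = int(frame_rate * play_steps_in_s)        # 86 steps -> 1 s of audio
time_to_first_audio = chunk_steps / steps_per_second   # ~0.5 s before the first chunk is ready
faster_than_realtime = steps_per_second >= frame_rate  # must hold for gap-free playback

print(chunk_steps, round(time_to_first_audio, 2), faster_than_realtime)
```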
### How Can I Use It?

We've added a [ParlerTTSStreamer](https://github.com/huggingface/parler-tts/blob/main/parler_tts/streamer.py) class to the library. Don't hesitate to adapt it to your use-case.

Here's how to create a generator out of the streamer:
```py
import torch
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
from transformers import AutoTokenizer
from threading import Thread

torch_device = "cuda:0"  # Use "mps" for Mac
torch_dtype = torch.bfloat16
model_name = "parler-tts/parler-tts-mini-v1"

# need to set padding max length
max_length = 50

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
).to(torch_device, dtype=torch_dtype)

sampling_rate = model.audio_encoder.config.sampling_rate
frame_rate = model.audio_encoder.config.frame_rate

def generate(text, description, play_steps_in_s=0.5):
    play_steps = int(frame_rate * play_steps_in_s)
    streamer = ParlerTTSStreamer(model, device=torch_device, play_steps=play_steps)
    # tokenization
    inputs = tokenizer(description, return_tensors="pt").to(torch_device)
    prompt = tokenizer(text, return_tensors="pt").to(torch_device)
    # create generation kwargs
    generation_kwargs = dict(
        input_ids=inputs.input_ids,
        prompt_input_ids=prompt.input_ids,
        attention_mask=inputs.attention_mask,
        prompt_attention_mask=prompt.attention_mask,
        streamer=streamer,
        do_sample=True,
        temperature=1.0,
        min_new_tokens=10,
    )
    # initialize Thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # iterate over chunks of audio
    for new_audio in streamer:
        if new_audio.shape[0] == 0:
            break
        print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 4)} seconds")
        yield sampling_rate, new_audio


# now you can do
text = "This is a test of the streamer class"
description = "Jon's talking really fast."

chunk_size_in_s = 0.5

for (sampling_rate, audio_chunk) in generate(text, description, chunk_size_in_s):
    # You can do everything that you need with the chunk now
    # For example: stream it, save it, play it.
    print(audio_chunk.shape)
```
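If you want to hear the chunks as they are produced, one possibility (not part of Parler-TTS; it assumes the third-party `sounddevice` package is installed) is to play each chunk as it comes out of the generator defined above:

```py
# naive playback sketch using the third-party sounddevice package (pip install sounddevice);
# a production setup would feed a single output stream to avoid small gaps between chunks
import sounddevice as sd

for sampling_rate, audio_chunk in generate(text, description, chunk_size_in_s):
    sd.play(audio_chunk, samplerate=sampling_rate)  # start playing this chunk
    sd.wait()                                       # block until the chunk has finished
```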
## Batch generation

Batching combines the operations for multiple samples, so that the overall time spent generating a batch is lower than generating each sample one by one.

Here is a quick example of how you can use it:
```py
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
import scipy.io.wavfile


repo_id = "parler-tts/parler-tts-mini-v1"

model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(repo_id, padding_side="left")
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)

input_text = ["Hey, how are you doing?", "I'm not sure how to feel about it."]
description = 2 * ["A male speaker with a monotone and high-pitched voice is delivering his speech at a really low speed in a confined environment."]

inputs = tokenizer(description, return_tensors="pt", padding=True).to("cuda")
prompt = tokenizer(input_text, return_tensors="pt", padding=True).to("cuda")

set_seed(0)
generation = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    prompt_input_ids=prompt.input_ids,
    prompt_attention_mask=prompt.attention_mask,
    do_sample=True,
    return_dict_in_generate=True,
)

audio_1 = generation.sequences[0, :generation.audios_length[0]]
audio_2 = generation.sequences[1, :generation.audios_length[1]]

print(audio_1.shape, audio_2.shape)
scipy.io.wavfile.write("sample_out.wav", rate=feature_extractor.sampling_rate, data=audio_1.cpu().numpy().squeeze())
scipy.io.wavfile.write("sample_out_2.wav", rate=feature_extractor.sampling_rate, data=audio_2.cpu().numpy().squeeze())
```

README.md

Lines changed: 68 additions & 18 deletions
@@ -7,15 +7,23 @@ Contrarily to other TTS models, Parler-TTS is a **fully open-source** release. A
 This repository contains the inference and training code for Parler-TTS. It is designed to accompany the [Data-Speech](https://github.com/huggingface/dataspeech) repository for dataset annotation.
 
 > [!IMPORTANT]
-> We're proud to release [Parler-TTS Mini v0.1](https://huggingface.co/parler-tts/parler_tts_mini_v0.1), our first 600M parameter model, trained on 10.5K hours of audio data.
-> In the coming weeks, we'll be working on scaling up to 50k hours of data, in preparation for the v1 model.
+> **08/08/2024:** We are proud to release two new Parler-TTS checkpoints:
+> 1. [Parler-TTS Mini](https://huggingface.co/parler-tts/parler-tts-mini-v1), an 880M parameter model.
+> 2. [Parler-TTS Large](https://huggingface.co/parler-tts/parler-tts-large-v1), a 2.3B parameter model.
+>
+> These checkpoints have been trained on 45k hours of audiobook data.
+>
+> In addition, the code is optimized for much faster generation: we've added SDPA and Flash Attention 2 compatibility, as well as the ability to compile the model.
 
 ## 📖 Quick Index
 * [Installation](#installation)
 * [Usage](#usage)
+  - [🎲 Using a random voice](#-random-voice)
+  - [🎯 Using a specific speaker](#-using-a-specific-speaker)
 * [Training](#training)
-* [Demo](https://huggingface.co/spaces/parler-tts/parler_tts_mini)
+* [Demo](https://huggingface.co/spaces/parler-tts/parler_tts)
 * [Model weights and datasets](https://huggingface.co/parler-tts)
+* [Optimizing inference](#-optimizing-inference-speed)
 
 ## Installation
 
@@ -34,43 +42,85 @@ pip3 install --pre torch torchaudio --index-url https://download.pytorch.org/whl
 ## Usage
 
 > [!TIP]
-> You can directly try it out in an interactive demo [here](https://huggingface.co/spaces/parler-tts/parler_tts_mini)!
+> You can directly try it out in an interactive demo [here](https://huggingface.co/spaces/parler-tts/parler_tts)!
 
-Using Parler-TTS is as simple as "bonjour". Simply use the following inference snippet.
+Using Parler-TTS is as simple as "bonjour". Simply install the library once:
+
+```sh
+pip install git+https://github.com/huggingface/parler-tts.git
+```
+
+### 🎲 Random voice
+
+**Parler-TTS** has been trained to generate speech with features that can be controlled with a simple text prompt, for example:
 
 ```py
+import torch
 from parler_tts import ParlerTTSForConditionalGeneration
 from transformers import AutoTokenizer
 import soundfile as sf
-import torch
 
-device = "cpu"
-if torch.cuda.is_available():
-    device = "cuda:0"
-if torch.backends.mps.is_available():
-    device = "mps"
-if torch.xpu.is_available():
-    device = "xpu"
-torch_dtype = torch.float16 if device != "cpu" else torch.float32
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1", torch_dtype=torch_dtype).to(device)
+model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
+tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
 
-tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")
+prompt = "Hey, how are you doing today?"
+description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
+
+input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
+prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
+generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
+audio_arr = generation.cpu().numpy().squeeze()
+sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)
+```
+
+### 🎯 Using a specific speaker
+
+To ensure speaker consistency across generations, this checkpoint was also trained on 34 speakers, characterized by name (e.g. Jon, Lea, Gary, Jenna, Mike, Laura).
+
+To take advantage of this, simply adapt your text description to specify which speaker to use: `Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise.`
+
+```py
+import torch
+from parler_tts import ParlerTTSForConditionalGeneration
+from transformers import AutoTokenizer
+import soundfile as sf
+
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
+tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
 
 prompt = "Hey, how are you doing today?"
-description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."
+description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."
 
 input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
 prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
 
-generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids).to(torch.float32)
+generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
 audio_arr = generation.cpu().numpy().squeeze()
 sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)
 ```
 
+**Tips**:
+* Include the term "very clear audio" to generate the highest quality audio, and "very noisy audio" for high levels of background noise
+* Punctuation can be used to control the prosody of the generations, e.g. use commas to add small breaks in speech
+* The remaining speech features (gender, speaking rate, pitch and reverberation) can be controlled directly through the prompt
+
+### ✨ Optimizing Inference Speed
+
+We've set up an [inference guide](INFERENCE.md) to make generation faster. Think SDPA, torch.compile and streaming!
+
 https://github.com/huggingface/parler-tts/assets/52246514/251e2488-fe6e-42c1-81cd-814c5b7795b0
 
 ## Training
+> [!WARNING]
+> The training guide has yet to be adapted to the newest checkpoints.
+
 <a target="_blank" href="https://colab.research.google.com/github/ylacombe/scripts_and_notebooks/blob/main/Finetuning_Parler_TTS_on_a_single_speaker_dataset.ipynb">
   <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
 </a>
parler_tts/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,4 @@
-__version__ = "0.1"
+__version__ = "0.2"
 
 
 from transformers import AutoConfig, AutoModel
@@ -12,6 +12,7 @@
     build_delay_pattern_mask,
 )
 
+from .streamer import ParlerTTSStreamer
 
 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
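With this re-export, the streamer used in the inference guide can be imported straight from the package root:

```py
from parler_tts import ParlerTTSStreamer  # available from v0.2 onwards
```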

0 commit comments
