Commit 9271958

Merge branch 'main' into remove-deprecated-repository-class
2 parents: a04d685 + b2b749d

File tree

10 files changed: +841 −798 lines changed

README.md

Lines changed: 19 additions & 3 deletions

````diff
@@ -25,6 +25,12 @@ Parler-TTS has light-weight dependencies and can be installed in one line:
 pip install git+https://github.com/huggingface/parler-tts.git
 ```
 
+Apple Silicon users will need to run a follow-up command to make use of the nightly PyTorch (2.4) build for bfloat16 support:
+
+```sh
+pip3 install --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
+```
+
 ## Usage
 
 > [!TIP]
````
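A quick aside on the nightly install added above: Apple Silicon users can confirm the wheel took effect before running the usage snippet that follows. A minimal sketch (the exact version string printed is an assumption about the nightly tag format):

```python
import torch

# After the nightly install, the version string carries a dev tag and MPS
# should be available on Apple Silicon.
print(torch.__version__)                  # e.g. "2.4.0.dev..." from the nightly index
print(torch.backends.mps.is_available())  # expect True on Apple Silicon
```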
````diff
@@ -38,9 +44,16 @@ from transformers import AutoTokenizer
 import soundfile as sf
 import torch
 
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
+device = "cpu"
+if torch.cuda.is_available():
+    device = "cuda:0"
+if torch.backends.mps.is_available():
+    device = "mps"
+if torch.xpu.is_available():
+    device = "xpu"
+torch_dtype = torch.float16 if device != "cpu" else torch.float32
 
-model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device)
+model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device, dtype=torch_dtype)
 tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")
 
 prompt = "Hey, how are you doing today?"
@@ -49,14 +62,17 @@ description = "A female speaker with a slightly low-pitched voice delivers her w
 input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
 prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
 
-generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
+generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids).to(torch.float32)
 audio_arr = generation.cpu().numpy().squeeze()
 sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)
 ```
 
 https://github.com/huggingface/parler-tts/assets/52246514/251e2488-fe6e-42c1-81cd-814c5b7795b0
 
 ## Training
+<a target="_blank" href="https://colab.research.google.com/github/ylacombe/scripts_and_notebooks/blob/main/Finetuning_Parler_TTS_on_a_single_speaker_dataset.ipynb">
+  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
+</a>
 
 The [training folder](/training/) contains all the information to train or fine-tune your own Parler-TTS model. It consists of:
 - [1. An introduction to the Parler-TTS architecture](/training/README.md#1-architecture)
````
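One detail worth flagging in the usage hunk: the added `.to(torch.float32)` on the generate output is what keeps the new half-precision path writable, since python-soundfile only accepts float32/float64 (or integer) arrays. A minimal sketch of the constraint, with a zero tensor standing in for model output:

```python
import soundfile as sf
import torch

# Stand-in for half-precision model output produced on a GPU/MPS device.
generation = torch.zeros(16000, dtype=torch.float16)

# sf.write rejects float16 data, so the tensor is upcast before the numpy
# conversion, exactly as the README diff does on the generate() result.
audio_arr = generation.to(torch.float32).cpu().numpy().squeeze()
sf.write("parler_tts_out.wav", audio_arr, 16000)
```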

parler_tts/modeling_parler_tts.py

Lines changed: 14 additions & 30 deletions

````diff
@@ -1386,8 +1386,6 @@ def generate(
         batch_size = input_ids.shape[0] // self.num_codebooks
 
         # 4. Define other model kwargs
-        model_kwargs["output_attentions"] = generation_config.output_attentions
-        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
         model_kwargs["use_cache"] = generation_config.use_cache
         model_kwargs["guidance_scale"] = generation_config.guidance_scale
 
@@ -1481,14 +1479,11 @@ def generate(
             )
 
             # 11. run greedy search
-            outputs = self.greedy_search(
+            outputs = self._greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
-                pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
-                output_scores=generation_config.output_scores,
-                return_dict_in_generate=generation_config.return_dict_in_generate,
+                generation_config=generation_config,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
                 **model_kwargs,
@@ -1506,15 +1501,12 @@ def generate(
             )
 
             # 12. run sample
-            outputs = self.sample(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
                 stopping_criteria=stopping_criteria,
-                pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
-                output_scores=generation_config.output_scores,
-                return_dict_in_generate=generation_config.return_dict_in_generate,
+                generation_config=generation_config,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
                 **model_kwargs,
@@ -2198,8 +2190,8 @@ def _prepare_text_encoder_kwargs_for_generation(
         self,
         inputs_tensor: torch.Tensor,
         model_kwargs,
-        model_input_name: Optional[str] = None,
-        guidance_scale: Optional[float] = None,
+        model_input_name: Optional[str],
+        generation_config: GenerationConfig,
     ) -> Dict[str, Any]:
         # 1. get text encoder
         encoder = self.get_text_encoder()
@@ -2221,6 +2213,9 @@ def _prepare_text_encoder_kwargs_for_generation(
         encoder_kwargs = {
             argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
         }
+        encoder_kwargs["output_attentions"] = generation_config.output_attentions
+        encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+        guidance_scale = generation_config.guidance_scale
 
         # 3. make sure that encoder returns `ModelOutput`
         model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name
@@ -2452,8 +2447,6 @@ def generate(
         batch_size = inputs_tensor.shape[0]
 
         # 4. Define other model kwargs
-        model_kwargs["output_attentions"] = generation_config.output_attentions
-        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
         model_kwargs["use_cache"] = generation_config.use_cache
         model_kwargs["guidance_scale"] = generation_config.guidance_scale
 
@@ -2467,10 +2460,7 @@ def generate(
         if "encoder_outputs" not in model_kwargs:
             # encoder_outputs are created and added to `model_kwargs`
             model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
-                inputs_tensor,
-                model_kwargs,
-                model_input_name,
-                guidance_scale=generation_config.guidance_scale,
+                inputs_tensor, model_kwargs, model_input_name, generation_config,
             )
 
         if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
@@ -2579,14 +2569,11 @@ def generate(
             )
 
             # 11. run greedy search
-            outputs = self.greedy_search(
+            outputs = self._greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
-                pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
-                output_scores=generation_config.output_scores,
-                return_dict_in_generate=generation_config.return_dict_in_generate,
+                generation_config=generation_config,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
                 **model_kwargs,
@@ -2605,15 +2592,12 @@ def generate(
             )
 
             # 12. run sample
-            outputs = self.sample(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
                 stopping_criteria=stopping_criteria,
-                pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
-                output_scores=generation_config.output_scores,
-                return_dict_in_generate=generation_config.return_dict_in_generate,
+                generation_config=generation_config,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
                 **model_kwargs,
````
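The through-line in these hunks is a transformers API migration: in the release window this commit pins to (see setup.py below), the public `greedy_search`/`sample` entry points became the private `_greedy_search`/`_sample`, which read per-call flags such as `pad_token_id`, `eos_token_id`, `output_scores`, and `return_dict_in_generate` off the `GenerationConfig` object rather than taking them as individual kwargs. A hypothetical sketch of the new call shape (the function below is a placeholder, not the actual transformers internals):

```python
from transformers import GenerationConfig

# Placeholder values; the real Parler-TTS config carries its own token ids.
generation_config = GenerationConfig(
    pad_token_id=0,
    eos_token_id=0,
    output_scores=False,
    return_dict_in_generate=True,
)

def decoding_loop(input_ids, generation_config, **model_kwargs):
    # Stand-in for _greedy_search/_sample: options are read off the config
    # instead of arriving as separate keyword arguments.
    pad_token_id = generation_config.pad_token_id
    return_dict = generation_config.return_dict_in_generate
    ...
```

This also explains the relocation of `output_attentions`/`output_hidden_states`: they no longer need to ride along in `model_kwargs`, so the encoder-side preparation now reads them from the config directly.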

setup.py

Lines changed: 1 addition & 1 deletion

````diff
@@ -17,7 +17,7 @@
 
 
 _deps = [
-    "transformers>=4.34.0",
+    "transformers>=4.39.0,<4.41.0",
     "torch",
     "sentencepiece",
     "descript-audio-codec",
````

training/README.md

Lines changed: 6 additions & 0 deletions

````diff
@@ -1,5 +1,9 @@
 # Training Parler-TTS
 
+<a target="_blank" href="https://colab.research.google.com/github/ylacombe/scripts_and_notebooks/blob/main/Finetuning_Parler_TTS_on_a_single_speaker_dataset.ipynb">
+  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
+</a>
+
 **TL;DR:** After having followed the [installation steps](#requirements), you can reproduce the [Parler-TTS Mini v0.1](https://huggingface.co/parler-tts/parler_tts_mini_v0.1) training recipe with the following command line:
 
 ```sh
@@ -13,6 +17,8 @@ This sub-folder contains all the information to train or fine-tune your own Parl
 - [2. First steps to get started](#b-getting-started)
 - [3. Training guide](#c-training)
 
+> [!IMPORTANT]
+> You can also follow [this fine-tuning guide](https://colab.research.google.com/github/ylacombe/scripts_and_notebooks/blob/main/Finetuning_Parler_TTS_on_a_single_speaker_dataset.ipynb) for a mono-speaker dataset example.
 
 ## 1. Architecture
 
````
training/__init__.py

Whitespace-only changes.
