@@ -25,6 +25,12 @@ Parler-TTS has light-weight dependencies and can be installed in one line:
25
25
pip install git+https://github.com/huggingface/parler-tts.git
26
26
```
27
27
28
+ Apple Silicon users will need to run a follow-up command to make use of the nightly PyTorch (2.4) build for bfloat16 support:
29
+
30
+ ``` sh
31
+ pip3 install --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
32
+ ```
33
+
28
34
## Usage
29
35
30
36
> [!TIP]
@@ -38,9 +44,16 @@ from transformers import AutoTokenizer
38
44
import soundfile as sf
39
45
import torch
40
46
41
- device = " cuda:0" if torch.cuda.is_available() else " cpu"
47
+ device = " cpu"
48
+ if torch.cuda.is_available():
49
+ device = " cuda:0"
50
+ if torch.backends.mps.is_available():
51
+ device = " mps"
52
+ if torch.xpu.is_available():
53
+ device = " xpu"
54
+ torch_dtype = torch.float16 if device != " cpu" else torch.float32
42
55
43
- model = ParlerTTSForConditionalGeneration.from_pretrained(" parler-tts/parler_tts_mini_v0.1" ).to(device)
56
+ model = ParlerTTSForConditionalGeneration.from_pretrained(" parler-tts/parler_tts_mini_v0.1" ).to(device, dtype = torch_dtype )
44
57
tokenizer = AutoTokenizer.from_pretrained(" parler-tts/parler_tts_mini_v0.1" )
45
58
46
59
prompt = " Hey, how are you doing today?"
@@ -49,7 +62,7 @@ description = "A female speaker with a slightly low-pitched voice delivers her w
49
62
input_ids = tokenizer(description, return_tensors = " pt" ).input_ids.to(device)
50
63
prompt_input_ids = tokenizer(prompt, return_tensors = " pt" ).input_ids.to(device)
51
64
52
- generation = model.generate(input_ids = input_ids, prompt_input_ids = prompt_input_ids)
65
+ generation = model.generate(input_ids = input_ids, prompt_input_ids = prompt_input_ids).to(torch.float32)
53
66
audio_arr = generation.cpu().numpy().squeeze()
54
67
sf.write(" parler_tts_out.wav" , audio_arr, model.config.sampling_rate)
55
68
```
0 commit comments