Commit b2b749d

Merge pull request #7 from bghira/documentation/mps-xpu-example
add mps and xpu to examples
2 parents 840156b + 99f6e9a commit b2b749d

1 file changed: +16 -3 lines changed

README.md

Lines changed: 16 additions & 3 deletions
@@ -25,6 +25,12 @@ Parler-TTS has light-weight dependencies and can be installed in one line:
 pip install git+https://github.com/huggingface/parler-tts.git
 ```
 
+Apple Silicon users will need to run a follow-up command to use the nightly PyTorch (2.4) build for bfloat16 support:
+
+```sh
+pip3 install --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
+```
+
 ## Usage
 
 > [!TIP]
@@ -38,9 +44,16 @@ from transformers import AutoTokenizer
 import soundfile as sf
 import torch
 
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
+device = "cpu"
+if torch.cuda.is_available():
+    device = "cuda:0"
+if torch.backends.mps.is_available():
+    device = "mps"
+if torch.xpu.is_available():
+    device = "xpu"
+torch_dtype = torch.float16 if device != "cpu" else torch.float32
 
-model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device)
+model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device, dtype=torch_dtype)
 tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")
 
 prompt = "Hey, how are you doing today?"
@@ -49,7 +62,7 @@ description = "A female speaker with a slightly low-pitched voice delivers her w
 input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
 prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
 
-generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
+generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids).to(torch.float32)
 audio_arr = generation.cpu().numpy().squeeze()
 sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)
 ```
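
The device fallback added in this commit could also be written as a small standalone helper. The sketch below is not part of the commit: `pick_device` and `probe_backends` are hypothetical names, and the explicit `cuda`, then `mps`, then `xpu` order is an assumption (the committed consecutive `if` statements let the last available backend win instead). The `getattr` guards assume that some PyTorch builds may not expose `torch.xpu` or `torch.backends.mps`, and the probe reports no accelerators when `torch` is not installed.

```python
import importlib


def pick_device(cuda: bool, mps: bool, xpu: bool) -> str:
    """Return a device string using an explicit priority order (hypothetical helper)."""
    if cuda:
        return "cuda:0"
    if mps:
        return "mps"
    if xpu:
        return "xpu"
    return "cpu"


def probe_backends() -> dict:
    """Report which backends the local PyTorch build exposes (all False without torch)."""
    found = {"cuda": False, "mps": False, "xpu": False}
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        return found
    found["cuda"] = bool(torch.cuda.is_available())
    # Guard the optional backends: assumed to be missing on some builds.
    mps = getattr(torch.backends, "mps", None)
    found["mps"] = bool(mps is not None and mps.is_available())
    xpu = getattr(torch, "xpu", None)
    found["xpu"] = bool(xpu is not None and xpu.is_available())
    return found


device = pick_device(**probe_backends())
print(device)
```

Keeping the priority logic in a pure function makes it easy to check without any accelerator present; the resulting string can be passed to `.to(device, dtype=torch_dtype)` exactly as in the README example.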

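
The commit does not say why the generated audio is cast to `torch.float32` before it is converted to NumPy and written to disk; one plausible reason is that downstream audio writers generally expect 32-bit floats rather than half precision. The sketch below uses only the standard library's `struct` module (not PyTorch) to illustrate that widening a half-precision value to single precision is lossless: any rounding has already happened in float16. `to_half` and `to_single` are hypothetical helper names.

```python
import struct


def to_half(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_single(x: float) -> float:
    """Round x to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


sample = 0.1
half = to_half(sample)      # what a float16 model output would hold
widened = to_single(half)   # analogous to the float32 cast before writing the file

print(half == widened)            # True: widening loses nothing
print(abs(half - sample) < 1e-3)  # True: the small rounding error comes from float16
```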