Skip to content

Commit 7b3a834

Browse files
authored
Merge pull request #12 from kadirnar/fix-checkpoint-vui
fix: standardize output file naming and fix model loading in VUI inference
2 parents a7e3cd3 + 04a9f6c commit 7b3a834

File tree

5 files changed

+50
-10
lines changed

5 files changed

+50
-10
lines changed

README.md

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,54 @@ uv pip install voicehub
1515

1616
## 📚 Usage
1717

18+
VoiceHub provides a simple, unified interface for working with various Text-to-Speech (TTS) models. Below are examples showing how to use different supported TTS models with the same consistent approach.
19+
20+
### OrpheusTTS Model
21+
1822
```python
1923
from voicehub.automodel import AutoInferenceModel
2024

21-
# Create model using the static from_pretrained method
2225
model = AutoInferenceModel.from_pretrained(
2326
model_type="orpheustts", # or "dia" or "vui"
2427
model_path="canopylabs/orpheus-3b-0.1-ft",
2528
device="cuda",
2629
)
2730

28-
# Generate speech with the model
31+
output = model("Hello, how are you today?", voice="tara", output_file="output.wav")
32+
```
33+
34+
### DiaTTS Model
35+
36+
```python
37+
from voicehub.automodel import AutoInferenceModel
38+
39+
model = AutoInferenceModel.from_pretrained(
40+
model_type="dia",  # or "orpheustts" or "vui"
41+
model_path="dia/dia-100m-base.pt",
42+
device="cuda",
43+
)
44+
2945
output = model(
30-
"Hello, how are you today?", voice="tara", output_file="output"
31-
) # voice param is only for orpheustts
46+
text="Hey, here is some random stuff, the longer the text the less likely the model can cope!",
47+
output_file="output.wav",
48+
)
49+
```
50+
51+
### VuiTTS Model
52+
53+
```python
54+
from voicehub.automodel import AutoInferenceModel
55+
56+
model = AutoInferenceModel.from_pretrained(
57+
model_type="vui",  # or "orpheustts" or "dia"
58+
model_path="vui-100m-base.pt",
59+
device="cuda",
60+
)
61+
62+
output = model(
63+
text="Hey, here is some random stuff, the longer the text the less likely the model can cope!",
64+
output_file="output.wav",
65+
)
3266
```
3367

3468
## 🤗 Contributing
@@ -43,4 +77,4 @@ pre-commit run --all-files
4377

4478
- [Orpheus-TTS](https://github.com/canopyai/Orpheus-TTS)
4579
- [Dia](https://github.com/nari-labs/dia)
46-
- [VUI](https://github.com/fluxions-ai/vui)
80+
- [Vui](https://github.com/fluxions-ai/vui)

requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@ torchaudio
55
pydantic
66
descript-audio-codec
77
soundfile
8+
inflect
9+
pandas
10+
pyannote.audio

voicehub/models/orpheustts/inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def _postprocess_tokens(self, generated_ids: torch.Tensor) -> list:
145145

146146
return adjusted
147147

148-
def __call__(self, prompt: str, voice: str, output_file: str = "sample"):
148+
def __call__(self, prompt: str, voice: str, output_file: str = "output.wav"):
149149
"""
150150
Generate speech from text prompts.
151151
@@ -181,7 +181,7 @@ def __call__(self, prompt: str, voice: str, output_file: str = "sample"):
181181
audio = self._redistribute_codes(codes)
182182
# Save as 24kHz WAV file
183183
sf.write(
184-
f"{output_file}.wav",
184+
f"{output_file}",
185185
audio.detach().squeeze().cpu().numpy(),
186186
24000,
187187
)

voicehub/models/vui/inference.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
import torchaudio
22

3-
from voicehub.models.vui.inference import render
43
from voicehub.models.vui.model import Vui
4+
from voicehub.models.vui.tts import render
55

66

77
class VuiTTS:
88

99
def __init__(self, model_path: str, device: str = "cuda"):
1010
self.model_path = model_path
11+
self.device = device
1112
self.model = None
1213

1314
def load_model(self):
1415
model = Vui.from_pretrained(checkpoint_path=self.model_path).to(self.device)
1516
self.model = model
1617

1718
def __call__(self, text: str, output_file: str = "output.wav"):
19+
if self.model is None:
20+
self.load_model()
1821
waveform = render(self.model, text)
1922
torchaudio.save(output_file, waveform[0], 22050)

voicehub/models/vui/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,8 +386,8 @@ def from_pretrained(
386386
from huggingface_hub import hf_hub_download
387387

388388
checkpoint_path = hf_hub_download(
389-
"fluxions/vui",
390-
checkpoint_path,
389+
repo_id="fluxions/vui",
390+
filename=checkpoint_path,
391391
)
392392
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
393393

0 commit comments

Comments
 (0)