Skip to content

Commit 3e265b1

Browse files
authored
Add kokoro support (#94)
* Add kokoro support * Update demo to use kokoro * Use `am_michael` instead of `am_adam` * Install kokoro deps in dockerfile * Revert changes for outetts * Use 0.0 as pad. Add if * enh(demo): Use HF_SPACE env var in order to choose model
1 parent 12b3b0e commit 3e265b1

File tree

13 files changed

+1443
-19
lines changed

13 files changed

+1443
-19
lines changed

demo/Dockerfile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ RUN apt-get update && apt-get install --no-install-recommends -y \
88
git \
99
&& apt-get clean && rm -rf /var/lib/apt/lists/*
1010

11+
RUN apt-get install espeak-ng -y
12+
1113
RUN useradd -m -u 1000 user
1214

1315
USER user
@@ -18,6 +20,7 @@ ENV HOME=/home/user \
1820
WORKDIR $HOME/app
1921

2022
RUN pip3 install https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.4-cu122/llama_cpp_python-0.3.4-cp310-cp310-linux_x86_64.whl
23+
RUN pip3 install phonemizer
2124
RUN pip3 install document-to-podcast
2225

2326
COPY --chown=user . $HOME/app

demo/app.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Streamlit app for converting documents to podcasts."""
22

3+
import io
4+
import os
35
import re
46
from pathlib import Path
5-
import io
67

78
import numpy as np
89
import soundfile as sf
@@ -28,7 +29,10 @@ def load_text_to_text_model():
2829

2930
@st.cache_resource
def load_text_to_speech_model():
    """Load and cache the text-to-speech model.

    Uses Kokoro when running on a Hugging Face Space (HF_SPACE == "TRUE"),
    otherwise the OuteTTS GGUF model for local runs.
    """
    if os.environ.get("HF_SPACE") == "TRUE":
        return load_tts_model("hexgrad/Kokoro-82M/kokoro-v0_19.pth")
    return load_tts_model("OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf")
3236

3337

3438
def numpy_to_wav(audio_array: np.ndarray, sample_rate: int) -> io.BytesIO:
@@ -115,10 +119,15 @@ def gen_button_clicked():
115119
text_model = load_text_to_text_model()
116120
speech_model = load_text_to_speech_model()
117121

122+
if os.environ.get("HF_SPACE") == "TRUE":
123+
tts_link = "- [hexgrad/Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M)"
124+
else:
125+
tts_link = "- [OuteAI/OuteTTS-0.2-500M](https://huggingface.co/OuteAI/OuteTTS-0.2-500M-GGUF)"
126+
118127
st.markdown(
119128
"For this demo, we are using the following models: \n"
120129
"- [Qwen2.5-3B-Instruct](https://huggingface.co/bartowski/Qwen2.5-3B-Instruct-GGUF)\n"
121-
"- [OuteAI/OuteTTS-0.2-500M](https://huggingface.co/OuteAI/OuteTTS-0.2-500M-GGUF)"
130+
f"{tts_link}\n"
122131
)
123132
st.markdown(
124133
"You can check the [Customization Guide](https://mozilla-ai.github.io/document-to-podcast/customization/)"
@@ -187,7 +196,7 @@ def gen_button_clicked():
187196

188197
if st.session_state[gen_button]:
189198
audio_np = stack_audio_segments(
190-
st.session_state.audio, speech_model.sample_rate
199+
st.session_state.audio, speech_model.sample_rate, silence_pad=0.0
191200
)
192201
audio_wav = numpy_to_wav(audio_np, speech_model.sample_rate)
193202
if st.download_button(

demo/notebook.ipynb

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,18 @@
7979
"metadata": {},
8080
"outputs": [],
8181
"source": [
82-
"%pip install --quiet https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.4-cu122/llama_cpp_python-0.3.4-cp310-cp310-linux_x86_64.whl\n",
83-
"%pip install --quiet document-to-podcast"
82+
"%pip install --quiet https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.4-cu122/llama_cpp_python-0.3.4-cp311-cp311-linux_x86_64.whl\n",
83+
"%pip install --quiet git+https://github.com/mozilla-ai/document-to-podcast.git@text-to-speech-model\n",
84+
"%pip install --quiet phonemizer"
85+
]
86+
},
87+
{
88+
"cell_type": "code",
89+
"execution_count": null,
90+
"metadata": {},
91+
"outputs": [],
92+
"source": [
93+
"!apt-get -qq -y install espeak-ng"
8494
]
8595
},
8696
{
@@ -173,7 +183,7 @@
173183
"source": [
174184
"For this demo, we are using the following models:\n",
175185
" - [Qwen2.5-3B-Instruct](https://huggingface.co/bartowski/Qwen2.5-3B-Instruct-GGUF)\n",
176-
" - [OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf](https://huggingface.co/OuteAI/OuteTTS-0.2-500M-GGUF)"
186+
" - [hexgrad/Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M)"
177187
]
178188
},
179189
{
@@ -197,7 +207,7 @@
197207
"text_model = load_llama_cpp_model(\n",
198208
" \"bartowski/Qwen2.5-3B-Instruct-GGUF/Qwen2.5-3B-Instruct-f16.gguf\"\n",
199209
")\n",
200-
"speech_model = load_tts_model(\"OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf\")"
210+
"speech_model = load_tts_model(\"hexgrad/Kokoro-82M/kokoro-v0_19.pth\")"
201211
]
202212
},
203213
{
@@ -247,15 +257,15 @@
247257
"speakers = [\n",
248258
" {\n",
249259
" \"id\": 1,\n",
250-
" \"name\": \"Laura\",\n",
260+
" \"name\": \"Sarah\",\n",
251261
" \"description\": \"The main host. She explains topics clearly using anecdotes and analogies, teaching in an engaging and captivating way.\",\n",
252-
" \"voice_profile\": \"female_1\",\n",
262+
" \"voice_profile\": \"af_sarah\",\n",
253263
" },\n",
254264
" {\n",
255265
" \"id\": 2,\n",
256-
" \"name\": \"Jon\",\n",
266+
" \"name\": \"Michael\",\n",
257267
" \"description\": \"The co-host. He keeps the conversation on track, asks curious follow-up questions, and reacts with excitement or confusion, often using interjections like hmm or umm.\",\n",
258-
" \"voice_profile\": \"male_1\",\n",
268+
" \"voice_profile\": \"am_michael\",\n",
259269
" },\n",
260270
"]\n",
261271
"\n",

docs/index.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ These docs are your companion to mastering the **Document-to-Podcast Blueprint**
1111
### Built with
1212
- Python 3.10+
1313
- [Llama-cpp](https://github.com/abetlen/llama-cpp-python) (text-to-text, i.e script generation)
14-
- [OuteAI](https://github.com/edwko/OuteTTS) (text-to-speech, i.e audio generation)
1514
- [Streamlit](https://streamlit.io/) (UI demo)
1615

1716

src/document_to_podcast/cli.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def document_to_podcast(
5252
- {output_folder}/podcast.txt
5353
- {output_folder}/podcast.wav
5454
55-
text_to_text_model (str, optional): The path to the text-to-text model.
55+
text_to_text_model (str, optional): The text-to-text model_id.
5656
5757
Need to be formatted as `owner/repo/file`.
5858
@@ -63,8 +63,8 @@ def document_to_podcast(
6363
text_to_text_prompt (str, optional): The prompt for the text-to-text model.
6464
Defaults to DEFAULT_PROMPT.
6565
66-
text_to_speech_model (str, optional): The path to the text-to-speech model.
67-
Defaults to `OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf`.
66+
text_to_speech_model (str, optional): The text-to-speech model_id.
67+
Defaults to `OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf`.
6868
6969
speakers (list[Speaker] | None, optional): The speakers for the podcast.
7070
Defaults to DEFAULT_SPEAKERS.

src/document_to_podcast/inference/kokoro/__init__.py

Whitespace-only changes.
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import phonemizer
2+
import re
3+
import torch
4+
5+
6+
def split_num(num):
    """Spell out a matched number (decimal, clock time, or year) for TTS.

    *num* is a regex match object; returns the spoken replacement text.
    """
    text = num.group()
    # Plain decimals are handled later by point_num — pass them through.
    if "." in text:
        return text
    # Clock times: "10:00" -> "10 o'clock", "3:05" -> "3 oh 5", "12:45" -> "12 45".
    if ":" in text:
        hour, minute = map(int, text.split(":"))
        if minute == 0:
            return f"{hour} o'clock"
        if minute < 10:
            return f"{hour} oh {minute}"
        return f"{hour} {minute}"
    # Otherwise treat the first four digits as a year, e.g. "1984" or "1800s".
    year = int(text[:4])
    # Years before 1100 or ending in 00x are read as plain numbers.
    if year < 1100 or year % 1000 < 10:
        return text
    century, remainder = text[:2], int(text[2:4])
    plural = "s" if text.endswith("s") else ""
    if 100 <= year % 1000 <= 999:
        if remainder == 0:
            return f"{century} hundred{plural}"
        if remainder < 10:
            return f"{century} oh {remainder}{plural}"
    return f"{century} {remainder}{plural}"
28+
29+
30+
def flip_money(m):
    """Rewrite a matched currency amount ("$2.50", "£1") as spoken words."""
    amount = m.group()
    currency = "dollar" if amount[0] == "$" else "pound"
    # "$5 million" style: the magnitude word is already spelled out.
    if amount[-1].isalpha():
        return f"{amount[1:]} {currency}s"
    # Whole amounts with no decimal point, e.g. "$1" or "£20".
    if "." not in amount:
        suffix = "" if amount[1:] == "1" else "s"
        return f"{amount[1:]} {currency}{suffix}"
    # Split into bills and coins: "$2.5" -> "2 dollars and 50 cents".
    whole, frac = amount[1:].split(".")
    suffix = "" if whole == "1" else "s"
    # Right-pad so ".5" means 50 cents, not 5.
    coins = int(frac.ljust(2, "0"))
    if amount[0] == "$":
        coin_word = "cent" if coins == 1 else "cents"
    else:
        coin_word = "penny" if coins == 1 else "pence"
    return f"{whole} {currency}{suffix} and {coins} {coin_word}"
47+
48+
49+
def point_num(num):
    """Speak a matched decimal digit by digit: "3.14" -> "3 point 1 4"."""
    whole, frac = num.group().split(".")
    spoken_frac = " ".join(frac)
    return f"{whole} point {spoken_frac}"
52+
53+
54+
def normalize_text(text):
    """Normalize raw text into a form suitable for phonemization.

    Straightens quotes, spells out times/years/currency/decimals, and expands
    common abbreviations so the espeak backend receives unambiguous input.
    The substitution order below is significant — do not reorder.
    """
    # Straighten curly single quotes to plain apostrophes.
    text = text.replace(chr(8216), "'").replace(chr(8217), "'")
    # Fold guillemets into curly double quotes, then all curly doubles to '"'.
    text = text.replace("«", chr(8220)).replace("»", chr(8221))
    text = text.replace(chr(8220), '"').replace(chr(8221), '"')
    # Re-purpose guillemets to mark parentheses for the phonemizer.
    text = text.replace("(", "«").replace(")", "»")
    # Map fullwidth CJK punctuation to its ASCII equivalent plus a space.
    for a, b in zip("、。!,:;?", ",.!,:;?"):
        text = text.replace(a, b + " ")
    # Collapse non-newline whitespace runs and strip blank-line padding.
    text = re.sub(r"[^\S \n]", " ", text)
    text = re.sub(r" +", " ", text)
    text = re.sub(r"(?<=\n) +(?=\n)", "", text)
    # Expand title abbreviations (Dr., Mr., Ms., Mrs.) and bare "etc.".
    text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
    text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
    text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
    text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
    text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)
    # "yeah"/"yea" -> "ye'a" so espeak pronounces it naturally.
    text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
    # Spell out decimals, 4-digit years, and clock times (see split_num).
    text = re.sub(
        r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
    )
    # Drop thousands separators: "1,000" -> "1000".
    text = re.sub(r"(?<=\d),(?=\d)", "", text)
    # Spell out dollar/pound amounts (see flip_money).
    text = re.sub(
        r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
        flip_money,
        text,
    )
    # Read remaining decimals digit by digit (see point_num).
    text = re.sub(r"\d*\.\d+", point_num, text)
    # Numeric ranges: "5-10" -> "5 to 10".
    text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
    # Separate a trailing capital S from a number: "90S" -> "90 S".
    text = re.sub(r"(?<=\d)S", " S", text)
    # Normalize possessive/plural 's after consonant-final acronyms to 'S.
    text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
    text = re.sub(r"(?<=X')S\b", "s", text)
    # Hyphenate dotted initialisms followed by a lowercase word ("U.S. army").
    text = re.sub(
        r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
    )
    text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
    return text.strip()
89+
90+
91+
def get_vocab():
    """Build the symbol-to-id vocabulary used by the Kokoro tokenizer.

    Index 0 is the pad symbol ("$"); punctuation, Latin letters, and IPA
    phoneme characters follow, numbered in order of appearance.
    """
    _pad = "$"
    _punctuation = ';:,.!?¡¿—…"«»“” '
    _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
    symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
    # enumerate + dict comprehension instead of indexing via range(len(...)).
    return {symbol: index for index, symbol in enumerate(symbols)}


# Module-level vocabulary shared by tokenize() and phonemize().
VOCAB = get_vocab()


def tokenize(ps):
    """Map a phoneme string to a list of vocab ids, dropping unknown symbols."""
    return [i for i in map(VOCAB.get, ps) if i is not None]
108+
109+
110+
# Espeak-backed phonemizer backends, keyed by language code:
# "a" = American English, "b" = British English.
phonemizers = dict(
    a=phonemizer.backend.EspeakBackend(
        language="en-us", preserve_punctuation=True, with_stress=True
    ),
    b=phonemizer.backend.EspeakBackend(
        language="en-gb", preserve_punctuation=True, with_stress=True
    ),
)
118+
119+
120+
def phonemize(text, lang, norm=True):
    """Convert *text* to a phoneme string for language *lang* ("a"=US, "b"=GB).

    When *norm* is True the text is first passed through normalize_text.
    Symbols not present in VOCAB are stripped from the result.
    """
    if norm:
        text = normalize_text(text)
    ps = phonemizers[lang].phonemize([text])
    ps = ps[0] if ps else ""
    # Fix espeak's mispronunciation of the word "kokoro" itself.
    # https://en.wiktionary.org/wiki/kokoro#English
    ps = ps.replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ").replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ")
    # Map espeak output symbols to the closest equivalents present in VOCAB.
    ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l")
    # Insert a space when espeak glues "hundred" onto the previous word.
    ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
    # Re-attach a detached trailing "z" (plural) to the preceding word.
    ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', "z", ps)
    if lang == "a":
        # American English flapped-t: "ninety" is pronounced "ninedy".
        ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)
    ps = "".join(filter(lambda p: p in VOCAB, ps))
    return ps.strip()
134+
135+
136+
def length_to_mask(lengths):
    """Build a boolean padding mask from a 1-D tensor of sequence lengths.

    Returns a (batch, max_len) tensor where True marks positions *beyond*
    each sequence's length (i.e. padding positions).
    """
    batch_size = lengths.shape[0]
    positions = (
        torch.arange(lengths.max())
        .unsqueeze(0)
        .expand(batch_size, -1)
        .type_as(lengths)
    )
    # Position i is padding when i + 1 exceeds the sequence length.
    return torch.gt(positions + 1, lengths.unsqueeze(1))
145+
146+
147+
@torch.no_grad()
def forward(model, tokens, ref_s, speed):
    """Run Kokoro inference: token ids + reference style -> waveform (numpy).

    *model* is a dict of sub-modules (bert, bert_encoder, predictor,
    text_encoder, decoder); *ref_s* is the voicepack style vector for this
    sequence length; *speed* > 1 shortens predicted durations (faster speech).
    """
    device = ref_s.device
    # Surround the sequence with 0 (pad id) as start/end tokens.
    tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    text_mask = length_to_mask(input_lengths).to(device)
    bert_dur = model["bert"](tokens, attention_mask=(~text_mask).int())
    d_en = model["bert_encoder"](bert_dur).transpose(-1, -2)
    # NOTE(review): ref_s appears to pack [decoder style | predictor style]
    # in 128-dim halves (see ref_s[:, :128] below) — confirm against the
    # Kokoro checkpoint layout.
    s = ref_s[:, 128:]
    d = model["predictor"].text_encoder(d_en, s, input_lengths, text_mask)
    x, _ = model["predictor"].lstm(d)
    duration = model["predictor"].duration_proj(x)
    # Per-token frame counts; dividing by speed speeds up / slows down speech.
    duration = torch.sigmoid(duration).sum(axis=-1) / speed
    # Each token gets at least one frame.
    pred_dur = torch.round(duration).clamp(min=1).long()
    # Build a hard token->frame alignment matrix from the predicted durations.
    pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
    c_frame = 0
    for i in range(pred_aln_trg.size(0)):
        pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
        c_frame += pred_dur[0, i].item()
    # Expand encoder features to frame rate, then predict F0 and noise.
    en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
    F0_pred, N_pred = model["predictor"].F0Ntrain(en, s)
    t_en = model["text_encoder"](tokens, input_lengths, text_mask)
    asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
    return (
        model["decoder"](asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
    )
173+
174+
175+
def generate(model, text, voicepack, lang="a", speed=1, ps=None):
    """Synthesize speech for *text* with the given model and voicepack.

    A pre-computed phoneme string may be supplied via *ps*; otherwise the
    text is phonemized for *lang* ("a"=US, "b"=GB English). Returns the
    waveform from forward(), or None when nothing remains to synthesize.
    """
    phonemes = ps if ps else phonemize(text, lang)
    tokens = tokenize(phonemes)
    if not tokens:
        return None
    # The model accepts at most 510 content tokens (512 with start/end pads).
    if len(tokens) > 510:
        tokens = tokens[:510]
        print("Truncated to 510 tokens")
    # Voicepacks store one reference style vector per sequence length.
    ref_s = voicepack[len(tokens)]
    return forward(model, tokens, ref_s, speed)

0 commit comments

Comments
 (0)