Commit a980618
Adds example for training a TTS model on top of a LLM. (axolotl-ai-cloud#2614)
* Adds example for training a TTS model on top of a LLM.
* Update examples/orpheus/finetune.yml
* Update examples/orpheus/finetune.yml
* Update README.md to clarify GPU requirements for finetuning Orpheus TTS model
* Update finetune.yml to use the new base model canopylabs/orpheus-3b-0.1-pretrained
* Update finetune.yml and README.md for consistency and clarity

Co-authored-by: NanoCode012 <[email protected]>
1 parent 54960d4 commit a980618

File tree: 2 files changed, +393 -0 lines changed

examples/orpheus/README.md

Lines changed: 341 additions & 0 deletions
@@ -0,0 +1,341 @@
# Finetuning LLMs to output audio

In this example, we finetune canopylabs/orpheus-3b-0.1-pretrained (a Llama 3.2 3B model) to output audio.

With the current settings, `finetune.yml` will run on any NVIDIA GPU with 45GB of VRAM or more. If you reduce the batch size, it can run on GPUs with less than 24GB of VRAM, as sketched below.
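For example, one way to trade batch size for gradient accumulation on a smaller card (a hypothetical sketch; the exact values that fit depend on your GPU and sequence lengths) is to change these two keys in `finetune.yml` so the effective batch size (micro_batch_size × gradient_accumulation_steps) stays at 32:

```yaml
# hypothetical overrides for a ~24GB GPU; tune to your hardware
micro_batch_size: 1              # down from 4
gradient_accumulation_steps: 32  # up from 8, keeping the effective batch at 32
```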
## Dataset pre-processing for pre-training

If you are adding another voice in English, please jump ahead to the finetune pre-processing section.

For this to work, we need to preprocess our dataset. Since we expect the model to output audio, we need to add audio tokens on top of the tokenizer's text vocabulary.

The code below downloads the SNAC codec model, encodes each audio clip into audio tokens, and uploads the final dataset.
```python
import os

import torch
import torchaudio.transforms as T
from datasets import load_dataset
from huggingface_hub import snapshot_download
from snac import SNAC
from transformers import AutoTokenizer

my_original_dataset_name = "<huggingface-id-of-dataset-that-we-want-to-preprocess>"
name_to_push_dataset_to = "<huggingface-id-of-where-to-save-dataset>"

dsn = my_original_dataset_name

snapshot_download(
    repo_id=dsn,
    repo_type="dataset",
    revision="main",
    max_workers=64,
)

ds = load_dataset(dsn, split="train")
ds_sample_rate = ds[0]["audio"]["sampling_rate"]

# Use one device consistently for both the SNAC model and the waveforms.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
model = model.to(device)


def tokenise_audio(waveform):
    waveform = torch.from_numpy(waveform).unsqueeze(0)
    waveform = waveform.to(dtype=torch.float32)
    # SNAC 24khz expects 24kHz audio, so resample from the dataset's native rate.
    resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
    waveform = resample_transform(waveform)

    waveform = waveform.unsqueeze(0).to(device)

    # generate the codes from snac
    with torch.inference_mode():
        codes = model.encode(waveform)

    # SNAC is hierarchical: per frame, layer 0 yields 1 code, layer 1 yields 2,
    # and layer 2 yields 4. Flatten each frame into 7 tokens, shifting each slot
    # into its own 4096-wide range above the text vocabulary
    # (128266 = 128256 text tokens + 10 special-token slots).
    all_codes = []
    for i in range(codes[0].shape[1]):
        all_codes.append(codes[0][0][i].item() + 128266)
        all_codes.append(codes[1][0][2 * i].item() + 128266 + 4096)
        all_codes.append(codes[2][0][4 * i].item() + 128266 + (2 * 4096))
        all_codes.append(codes[2][0][(4 * i) + 1].item() + 128266 + (3 * 4096))
        all_codes.append(codes[1][0][(2 * i) + 1].item() + 128266 + (4 * 4096))
        all_codes.append(codes[2][0][(4 * i) + 2].item() + 128266 + (5 * 4096))
        all_codes.append(codes[2][0][(4 * i) + 3].item() + 128266 + (6 * 4096))

    return all_codes


def add_codes(example):
    # Always initialize codes_list to None
    codes_list = None

    try:
        answer_audio = example.get("audio")
        # If there's a valid audio array, tokenise it
        if answer_audio and "array" in answer_audio:
            audio_array = answer_audio["array"]
            codes_list = tokenise_audio(audio_array)
    except Exception as e:
        print(f"Skipping row due to error: {e}")
        # Keep codes_list as None if we fail
    example["codes_list"] = codes_list

    return example


ds = ds.map(add_codes, remove_columns=["audio"])

# Special-token layout on top of the 128256-token Llama 3.2 vocabulary
tokeniser_length = 128256
start_of_text = 128000
end_of_text = 128009

start_of_speech = tokeniser_length + 1
end_of_speech = tokeniser_length + 2

start_of_human = tokeniser_length + 3
end_of_human = tokeniser_length + 4

start_of_ai = tokeniser_length + 5
end_of_ai = tokeniser_length + 6
pad_token = tokeniser_length + 7

audio_tokens_start = tokeniser_length + 10

tokenizer_name = "canopylabs/orpheus-3b-0.1-pretrained"

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
num_proc = max(1, os.cpu_count() - 2)  # leave a couple of cores free

ds = ds.filter(lambda x: x["codes_list"] is not None)
ds = ds.filter(lambda x: len(x["codes_list"]) > 0)


# Drop consecutive frames whose layer-0 code repeats
def remove_duplicate_frames(example):
    vals = example["codes_list"]
    if len(vals) % 7 != 0:
        raise ValueError("Input list length must be divisible by 7")

    result = vals[:7]

    removed_frames = 0

    for i in range(7, len(vals), 7):
        current_first = vals[i]
        previous_first = result[-7]

        if current_first != previous_first:
            result.extend(vals[i:i + 7])
        else:
            removed_frames += 1

    example["codes_list"] = result

    return example


ds = ds.map(remove_duplicate_frames, num_proc=num_proc)


# Create input ids:
# [start_of_human] text [end_of_human] [start_of_ai] [start_of_speech] codes [end_of_speech] [end_of_ai]
def create_input_ids(example):
    text_ids = tokenizer.encode(example["text"], add_special_tokens=True)
    text_ids.append(end_of_text)
    example["text_tokens"] = text_ids
    input_ids = (
        [start_of_human]
        + example["text_tokens"]
        + [end_of_human]
        + [start_of_ai]
        + [start_of_speech]
        + example["codes_list"]
        + [end_of_speech]
        + [end_of_ai]
    )
    example["input_ids"] = input_ids
    example["labels"] = input_ids
    example["attention_mask"] = [1] * len(input_ids)

    return example


ds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=["text", "codes_list"])

# Remove unnecessary columns
columns_to_keep = ["input_ids", "labels", "attention_mask"]
columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]

ds = ds.remove_columns(columns_to_remove)

ds.push_to_hub(name_to_push_dataset_to)
```
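As a quick sanity check (a minimal sketch, assuming the dataset was pushed as above), you can pull one row back and confirm that every example follows the `start_of_human … end_of_ai` layout and that the audio span is a whole number of 7-token frames:

```python
from datasets import load_dataset

ds = load_dataset("<huggingface-id-of-where-to-save-dataset>", split="train")
ids = ds[0]["input_ids"]

# Boundary tokens from the special-token layout above.
assert ids[0] == 128259   # start_of_human = 128256 + 3
assert ids[-1] == 128262  # end_of_ai = 128256 + 6
assert ids[-2] == 128258  # end_of_speech = 128256 + 2

# The audio tokens sit between start_of_speech (128257) and end_of_speech.
audio_span = ids[ids.index(128257) + 1 : ids.index(128258)]
assert len(audio_span) % 7 == 0, "audio codes should come in 7-token frames"
assert all(t >= 128266 for t in audio_span), "audio tokens start at 128266"
print(f"{len(audio_span) // 7} SNAC frames, {len(ids)} tokens total")
```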
## Finetune pre-processing

Use this code to add a new voice. It is the same pipeline as the pre-training script above, except that each text prompt is prefixed with a speaker id.
```python
import os

import torch
import torchaudio.transforms as T
from datasets import load_dataset
from huggingface_hub import snapshot_download
from snac import SNAC
from transformers import AutoTokenizer

my_original_dataset_name = "<huggingface-id-of-dataset-that-we-want-to-preprocess>"
name_to_push_dataset_to = "<huggingface-id-of-where-to-save-dataset>"

dsn = my_original_dataset_name

snapshot_download(
    repo_id=dsn,
    repo_type="dataset",
    revision="main",
    max_workers=64,
)

ds = load_dataset(dsn, split="train")
ds_sample_rate = ds[0]["audio"]["sampling_rate"]

# Use one device consistently for both the SNAC model and the waveforms.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
model = model.to(device)


def tokenise_audio(waveform):
    waveform = torch.from_numpy(waveform).unsqueeze(0)
    waveform = waveform.to(dtype=torch.float32)
    # SNAC 24khz expects 24kHz audio, so resample from the dataset's native rate.
    resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
    waveform = resample_transform(waveform)

    waveform = waveform.unsqueeze(0).to(device)

    # generate the codes from snac
    with torch.inference_mode():
        codes = model.encode(waveform)

    # Flatten each SNAC frame (1 + 2 + 4 codes across the three layers) into
    # 7 tokens, one 4096-wide id range per slot, starting at 128266.
    all_codes = []
    for i in range(codes[0].shape[1]):
        all_codes.append(codes[0][0][i].item() + 128266)
        all_codes.append(codes[1][0][2 * i].item() + 128266 + 4096)
        all_codes.append(codes[2][0][4 * i].item() + 128266 + (2 * 4096))
        all_codes.append(codes[2][0][(4 * i) + 1].item() + 128266 + (3 * 4096))
        all_codes.append(codes[1][0][(2 * i) + 1].item() + 128266 + (4 * 4096))
        all_codes.append(codes[2][0][(4 * i) + 2].item() + 128266 + (5 * 4096))
        all_codes.append(codes[2][0][(4 * i) + 3].item() + 128266 + (6 * 4096))

    return all_codes


def add_codes(example):
    # Always initialize codes_list to None
    codes_list = None

    try:
        answer_audio = example.get("audio")
        # If there's a valid audio array, tokenise it
        if answer_audio and "array" in answer_audio:
            audio_array = answer_audio["array"]
            codes_list = tokenise_audio(audio_array)
    except Exception as e:
        print(f"Skipping row due to error: {e}")
        # Keep codes_list as None if we fail
    example["codes_list"] = codes_list

    return example


ds = ds.map(add_codes, remove_columns=["audio"])

# Special-token layout on top of the 128256-token Llama 3.2 vocabulary
tokeniser_length = 128256
start_of_text = 128000
end_of_text = 128009

start_of_speech = tokeniser_length + 1
end_of_speech = tokeniser_length + 2

start_of_human = tokeniser_length + 3
end_of_human = tokeniser_length + 4

start_of_ai = tokeniser_length + 5
end_of_ai = tokeniser_length + 6
pad_token = tokeniser_length + 7

audio_tokens_start = tokeniser_length + 10

tokenizer_name = "canopylabs/orpheus-3b-0.1-pretrained"

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
num_proc = max(1, os.cpu_count() - 2)  # leave a couple of cores free

ds = ds.filter(lambda x: x["codes_list"] is not None)
ds = ds.filter(lambda x: len(x["codes_list"]) > 0)


# Drop consecutive frames whose layer-0 code repeats
def remove_duplicate_frames(example):
    vals = example["codes_list"]
    if len(vals) % 7 != 0:
        raise ValueError("Input list length must be divisible by 7")

    result = vals[:7]

    removed_frames = 0

    for i in range(7, len(vals), 7):
        current_first = vals[i]
        previous_first = result[-7]

        if current_first != previous_first:
            result.extend(vals[i:i + 7])
        else:
            removed_frames += 1

    example["codes_list"] = result

    return example


ds = ds.map(remove_duplicate_frames, num_proc=num_proc)

tok_info = '''*** HERE you can modify the text prompt
i.e. if you wanted a multispeaker model like canopylabs/orpheus-3b-0.1-ft, you could pass
f"{example["source"]}: {example["text"]}" instead of the speaker_id prompt used below.
'''
print(tok_info)


def create_input_ids(example):
    # Prefix the text with the speaker id so the model learns per-voice prompts
    text_ids = tokenizer.encode(f"{example['speaker_id']}: {example['text']}", add_special_tokens=True)
    text_ids.append(end_of_text)
    example["text_tokens"] = text_ids
    input_ids = (
        [start_of_human]
        + example["text_tokens"]
        + [end_of_human]
        + [start_of_ai]
        + [start_of_speech]
        + example["codes_list"]
        + [end_of_speech]
        + [end_of_ai]
    )
    example["input_ids"] = input_ids
    example["labels"] = input_ids
    example["attention_mask"] = [1] * len(input_ids)

    return example


ds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=["text", "codes_list"])

# Remove unnecessary columns
columns_to_keep = ["input_ids", "labels", "attention_mask"]
columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]

ds = ds.remove_columns(columns_to_remove)

ds.push_to_hub(name_to_push_dataset_to)
```
## Training

After preprocessing is done, fill out the blanks in `finetune.yml` and run `axolotl train finetune.yml`.
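If you want to cache the tokenized dataset before launching the run, recent axolotl CLI versions also expose a preprocess subcommand (an optional step; verify the subcommand exists in your installed version):

```bash
# optional: tokenize/cache the dataset into dataset_prepared_path first
axolotl preprocess finetune.yml

# launch the finetune
axolotl train finetune.yml
```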
## Inference

For inference, please refer to the original [Orpheus GitHub repo](https://github.com/canopyai/Orpheus-TTS/tree/main).
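If you want to turn generated tokens back into a waveform yourself, the decode step is just the inverse of `tokenise_audio` above. Here is a minimal sketch, assuming `generated_ids` holds only the audio tokens the finetuned model emitted between `start_of_speech` and `end_of_speech` (`codes_to_audio` is our own helper name, not part of the snac API):

```python
import torch
from snac import SNAC

snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()


def codes_to_audio(generated_ids):
    # Trim to whole 7-token frames, then undo the 128266 base offset and the
    # per-slot 4096 offsets, de-interleaving back into SNAC's three layers.
    generated_ids = generated_ids[: len(generated_ids) // 7 * 7]
    layer_0, layer_1, layer_2 = [], [], []
    for i in range(0, len(generated_ids), 7):
        f = generated_ids[i:i + 7]
        layer_0.append(f[0] - 128266)
        layer_1.append(f[1] - 128266 - 4096)
        layer_2.append(f[2] - 128266 - (2 * 4096))
        layer_2.append(f[3] - 128266 - (3 * 4096))
        layer_1.append(f[4] - 128266 - (4 * 4096))
        layer_2.append(f[5] - 128266 - (5 * 4096))
        layer_2.append(f[6] - 128266 - (6 * 4096))
    codes = [
        torch.tensor(layer_0).unsqueeze(0),
        torch.tensor(layer_1).unsqueeze(0),
        torch.tensor(layer_2).unsqueeze(0),
    ]
    with torch.inference_mode():
        audio = snac_model.decode(codes)  # shape (1, 1, num_samples) at 24kHz
    return audio
    # e.g. torchaudio.save("out.wav", audio.squeeze(0), 24000)
```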

examples/orpheus/finetune.yml

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
base_model: canopylabs/orpheus-3b-0.1-pretrained

hub_model_id: <your-hub-model-id>

plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true

datasets:
  - path: <your-hf-dataset-id>
    type: # leave empty to load pre-tokenized
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out

sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 8
micro_batch_size: 4
num_epochs: 3
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5

bf16: auto
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_steps: 20
evals_per_epoch: 5
saves_per_epoch: 5
weight_decay: 0.05

special_tokens:
  pad_token: <custom_token_7>
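Note that `<custom_token_7>` corresponds to the `pad_token = tokeniser_length + 7 = 128263` used in the preprocessing scripts. A quick sketch to double-check this on the tokenizer, assuming the layout described above:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("canopylabs/orpheus-3b-0.1-pretrained")
print(tokenizer.convert_tokens_to_ids("<custom_token_7>"))  # expected: 128263
```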
