
Commit f1fc984

Merge branch 'FunAudioLLM:main' into main
2 parents 0ca50ec + 95e99e0 commit f1fc984

File tree: 16 files changed, +198 -54 lines


README.md

Lines changed: 10 additions & 0 deletions
@@ -150,6 +150,16 @@ for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒
 # instruct usage
 for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
     torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+# bistream usage, you can use generator as input, this is useful when using text llm model as input
+# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
+def text_generator():
+    yield '收到好友从远方寄来的生日礼物,'
+    yield '那份意外的惊喜与深深的祝福'
+    yield '让我心中充满了甜蜜的快乐,'
+    yield '笑容如花儿般绽放。'
+for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
+    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
 ```

 **CosyVoice Usage**
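
The NOTE in the added README example asks callers to keep some basic sentence-split logic when streaming text in. A minimal sketch of such buffering, assuming a hypothetical `stream_text()` incremental source (for example, a text LLM's streamed output); this is illustrative only and not part of the repository:

```python
# Illustrative only: buffer a streamed text source into clause-sized chunks
# before passing the generator to inference_zero_shot.
def buffered_text_generator():
    buffer = ''
    for piece in stream_text():  # hypothetical incremental text source
        buffer += piece
        # flush on common sentence-ending punctuation so each yielded chunk
        # is a reasonably complete clause rather than an arbitrary fragment
        while True:
            cut = next((i for i, ch in enumerate(buffer) if ch in '。!?.!?'), None)
            if cut is None:
                break
            yield buffer[:cut + 1]
            buffer = buffer[cut + 1:]
    if buffer:
        yield buffer  # flush whatever remains at the end of the stream
```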

cosyvoice/cli/cosyvoice.py

Lines changed: 2 additions & 1 deletion
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 import time
+from typing import Generator
 from tqdm import tqdm
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
@@ -79,7 +80,7 @@ def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend
     def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
         prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            if len(i) < 0.5 * len(prompt_text):
+            if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
                 logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
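
This guard matters because a generator input reaches the same loop unchanged (text_normalize now returns it wrapped in a list, see the frontend.py diff below), and `len()` is not defined for generators. A small standalone illustration of the failure the check avoids, not taken from the repository:

```python
from typing import Generator

def chunks():
    yield 'hello, '
    yield 'world.'

prompt_text = 'a reasonably long prompt text'
for i in [chunks(), 'hi']:
    # without the isinstance check, len(i) would raise TypeError for the generator
    if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
        print('warning: synthesis text shorter than half the prompt:', i)
    else:
        print('no length check for', type(i).__name__)
```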

cosyvoice/cli/frontend.py

Lines changed: 20 additions & 4 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
+from typing import Generator
 import json
 import onnxruntime
 import torch
@@ -31,6 +32,7 @@
     from tn.chinese.normalizer import Normalizer as ZhNormalizer
     from tn.english.normalizer import Normalizer as EnNormalizer
     use_ttsfrd = False
+from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation


@@ -72,10 +74,21 @@ def __init__(self,
         self.inflect_parser = inflect.engine()

     def _extract_text_token(self, text):
-        text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
-        text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
-        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
-        return text_token, text_token_len
+        if isinstance(text, Generator):
+            logging.info('get tts_text generator, will return _extract_text_token_generator!')
+            # NOTE add a dummy text_token_len for compatibility
+            return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
+        else:
+            text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+            text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+            text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+            return text_token, text_token_len
+
+    def _extract_text_token_generator(self, text_generator):
+        for text in text_generator:
+            text_token, _ = self._extract_text_token(text)
+            for i in range(text_token.shape[1]):
+                yield text_token[:, i: i + 1]

     def _extract_speech_token(self, speech):
         assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
@@ -107,6 +120,9 @@ def _extract_speech_feat(self, speech):
         return speech_feat, speech_feat_len

     def text_normalize(self, text, split=True, text_frontend=True):
+        if isinstance(text, Generator):
+            logging.info('get tts_text generator, will skip text_normalize!')
+            return [text]
         if text_frontend is False:
             return [text] if split is True else text
         text = text.strip()
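
To see the generator path in isolation, here is a self-contained sketch of the `_extract_text_token_generator` pattern with a dummy character-level tokenizer (the tokenizer and names are placeholders, not the repository's): each incoming chunk is tokenized and its tokens are yielded one at a time as (1, 1) tensors, which is also why the paired text_token_len is only a dummy zero.

```python
import torch

def dummy_encode(text):
    # stand-in for tokenizer.encode(); one fake id per character
    return [ord(ch) % 1000 for ch in text]

def extract_text_token_generator(text_generator):
    # tokenize each chunk as it arrives, then yield single-token tensors
    for text in text_generator:
        text_token = torch.tensor([dummy_encode(text)], dtype=torch.int32)
        for i in range(text_token.shape[1]):
            yield text_token[:, i: i + 1]  # shape (1, 1)

def chunks():
    yield 'hello, '
    yield 'world.'

print([t.item() for t in extract_text_token_generator(chunks())])
```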

cosyvoice/cli/model.py

Lines changed: 19 additions & 8 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from typing import Generator
 import torch
 import numpy as np
 import threading
@@ -99,14 +100,24 @@ def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):

     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
-            for i in self.llm.inference(text=text.to(self.device),
-                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_text=prompt_text.to(self.device),
-                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                        embedding=llm_embedding.to(self.device)):
-                self.tts_speech_token_dict[uuid].append(i)
+            if isinstance(text, Generator):
+                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
+                for i in self.llm.inference_bistream(text=text,
+                                                     prompt_text=prompt_text.to(self.device),
+                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                     embedding=llm_embedding.to(self.device)):
+                    self.tts_speech_token_dict[uuid].append(i)
+            else:
+                for i in self.llm.inference(text=text.to(self.device),
+                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                            prompt_text=prompt_text.to(self.device),
+                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                            embedding=llm_embedding.to(self.device)):
+                    self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True

     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
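
For context, llm_job runs in a worker thread and streams speech tokens into tts_speech_token_dict[uuid], with llm_end_dict[uuid] flagging completion; the new branch only changes which LLM entry point feeds that stream. A stripped-down sketch of that producer/consumer shape, in generic Python rather than the repository's classes:

```python
import threading
import time
import uuid

tts_speech_token_dict, llm_end_dict = {}, {}

def llm_job(token_source, this_uuid):
    # producer: append tokens as the (fake) LLM emits them, then flag completion
    for tok in token_source:
        tts_speech_token_dict[this_uuid].append(tok)
    llm_end_dict[this_uuid] = True

def fake_llm():
    for tok in range(5):
        time.sleep(0.01)
        yield tok

this_uuid = str(uuid.uuid1())
tts_speech_token_dict[this_uuid], llm_end_dict[this_uuid] = [], False
t = threading.Thread(target=llm_job, args=(fake_llm(), this_uuid))
t.start()

# consumer: drain tokens as they become available until the producer is done
consumed = 0
while not (llm_end_dict[this_uuid] and consumed == len(tts_speech_token_dict[this_uuid])):
    pending = tts_speech_token_dict[this_uuid][consumed:]
    consumed += len(pending)
    if pending:
        print('got tokens', pending)
    time.sleep(0.005)
t.join()
```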

cosyvoice/dataset/processor.py

Lines changed: 9 additions & 4 deletions
@@ -20,6 +20,7 @@
 import torchaudio
 from torch.nn.utils.rnn import pad_sequence
 import torch.nn.functional as F
+import pyworld as pw


 AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
@@ -178,7 +179,7 @@ def compute_fbank(data,
         yield sample


-def compute_f0(data, pitch_extractor, mode='train'):
+def compute_f0(data, sample_rate, hop_size, mode='train'):
    """ Extract f0

        Args:
@@ -187,15 +188,19 @@ def compute_f0(data, pitch_extractor, mode='train'):
        Returns:
            Iterable[{key, feat, label}]
    """
+    frame_period = hop_size * 1000 / sample_rate
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
-        mat = pitch_extractor(waveform).transpose(1, 2)
-        mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
-        sample['pitch_feat'] = mat[0, 0]
+        _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
+        if sum(_f0 != 0) < 5:  # this happens when the algorithm fails
+            _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)  # if harvest fails, try dio
+        f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
+        f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
+        sample['pitch_feat'] = f0
        yield sample

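
The new pitch path uses pyworld directly: harvest as the primary F0 estimator, dio as a fallback when harvest returns almost no voiced frames (fewer than 5), stonemask refinement, then linear interpolation so the F0 track matches the speech-feature length. A standalone sketch of the same recipe on a synthetic tone (the sample rate, hop size, and frame count below are example values, not the repository's config), assuming pyworld is installed:

```python
import numpy as np
import torch
import torch.nn.functional as F
import pyworld as pw

sample_rate, hop_size = 24000, 480            # example values only
frame_period = hop_size * 1000 / sample_rate  # ms per analysis frame, as in compute_f0

t = np.arange(0, 1.0, 1 / sample_rate)
wav = 0.5 * np.sin(2 * np.pi * 220 * t)       # 1 s synthetic 220 Hz tone

x = wav.astype('double')
_f0, times = pw.harvest(x, sample_rate, frame_period=frame_period)
if sum(_f0 != 0) < 5:                          # harvest found (almost) nothing voiced
    _f0, times = pw.dio(x, sample_rate, frame_period=frame_period)
f0 = pw.stonemask(x, _f0, times, sample_rate)  # refine the raw estimate

num_feat_frames = 50                           # pretend speech_feat has 50 frames
f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1),
                   size=num_feat_frames, mode='linear').view(-1)
print(f0.shape, float(f0[num_feat_frames // 2]))  # roughly 220 Hz mid-utterance
```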

cosyvoice/llm/llm.py

Lines changed: 110 additions & 23 deletions
@@ -20,6 +20,7 @@
 from cosyvoice.utils.common import IGNORE_ID
 from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
 from cosyvoice.utils.common import th_accuracy
+from cosyvoice.utils.file_utils import logging


 class TransformerLM(torch.nn.Module):
@@ -144,10 +145,14 @@ def sampling_ids(
             sampling: int,
             ignore_eos: bool = True,
     ):
+        num_trials, max_trials = 0, 100
         while True:
             top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
             if (not ignore_eos) or (self.speech_token_size not in top_ids):
                 break
+            num_trials += 1
+            if num_trials > max_trials:
+                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
         return top_ids

     @torch.inference_mode()
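
The change above bounds what was previously an unbounded retry: with ignore_eos=True the sampler keeps redrawing until the candidate ids contain no EOS, and it now raises after 100 attempts instead of looping forever on degenerate inputs. A toy version of the same bounded rejection loop (the names and the fake sampler are illustrative, not the repository's API):

```python
import random

def sampling_with_retry(sample_fn, eos_id, ignore_eos=True, max_trials=100):
    # keep drawing until the sample contains no EOS (or EOS is allowed)
    num_trials = 0
    while True:
        top_ids = sample_fn()
        if (not ignore_eos) or (eos_id not in top_ids):
            return top_ids
        num_trials += 1
        if num_trials > max_trials:
            raise RuntimeError('still drawing eos after {} trials, check your input!'.format(max_trials))

# a fake sampler that returns EOS (id 0) some of the time
print(sampling_with_retry(lambda: [random.choice([0, 1, 2, 3])], eos_id=0))
```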
@@ -239,7 +244,7 @@ def forward_one_step(self, xs, masks, cache=None):
         return xs, new_cache


-class Qwen2LM(torch.nn.Module):
+class Qwen2LM(TransformerLM):
     def __init__(
             self,
             llm_input_size: int,
@@ -249,8 +254,9 @@ def __init__(
             sampling: Callable,
             length_normalized_loss: bool = True,
             lsm_weight: float = 0.0,
+            mix_ratio: List[int] = [5, 15],
     ):
-        super().__init__()
+        torch.nn.Module.__init__(self)
         self.llm_input_size = llm_input_size
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
@@ -275,23 +281,7 @@ def __init__(

         # 4. sampling method
         self.sampling = sampling
-
-    def sampling_ids(
-            self,
-            weighted_scores: torch.Tensor,
-            decoded_tokens: List,
-            sampling: int,
-            ignore_eos: bool = True,
-    ):
-        num_trials, max_trials = 0, 100
-        while True:
-            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
-            if (not ignore_eos) or (self.speech_token_size not in top_ids):
-                break
-            num_trials += 1
-            if num_trials > max_trials:
-                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
-        return top_ids
+        self.mix_ratio = mix_ratio

     @torch.inference_mode()
     def inference(
@@ -312,17 +302,14 @@ def inference(
         text_len += prompt_text_len
         text = self.llm.model.model.embed_tokens(text)

-        # 2. encode embedding
-        embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
-
         # 3. concat llm_input
         sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)

         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -345,3 +332,103 @@ def inference(
             yield top_ids
             out_tokens.append(top_ids)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+    @torch.inference_mode()
+    def inference_bistream(
+            self,
+            text: Generator,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+    ) -> Generator[torch.Tensor, None, None]:
+
+        device = prompt_text.device
+        # 1. prepare input
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
+        lm_input = torch.concat([sos_eos_emb], dim=1)
+
+        # 2. iterate text
+        out_tokens = []
+        cache = None
+        # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
+        text_cache = self.llm.model.model.embed_tokens(prompt_text)
+        next_fill_index = -1
+        for this_text in text:
+            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
+            # prompt_speech_token_emb not empty, try append to lm_input
+            while prompt_speech_token_emb.size(1) != 0:
+                if text_cache.size(1) >= self.mix_ratio[0]:
+                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
+                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
+                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
+                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
+                else:
+                    logging.info('not enough text token to decode, wait for more')
+                    break
+            # no prompt_speech_token_emb remain, can decode some speech token
+            if prompt_speech_token_emb.size(1) == 0:
+                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
+                    logging.info('get fill token, need to append more text token')
+                    if text_cache.size(1) >= self.mix_ratio[0]:
+                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
+                        logging.info('append {} text token'.format(lm_input_text.size(1)))
+                        if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
+                            lm_input = lm_input_text
+                        else:
+                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
+                        text_cache = text_cache[:, self.mix_ratio[0]:]
+                    else:
+                        logging.info('not enough text token to decode, wait for more')
+                        continue
+                while True:
+                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
+                    y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+                                                              cache=cache)
+                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
+                        top_ids = self.speech_token_size + 2
+                        next_fill_index += (self.mix_ratio[1] + 1)
+                    else:
+                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
+                        if top_ids == self.speech_token_size + 2:
+                            next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
+                            logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
+                    out_tokens.append(top_ids)
+                    if top_ids >= self.speech_token_size:
+                        if top_ids == self.speech_token_size + 2:
+                            break
+                        else:
+                            raise ValueError('should not get token {}'.format(top_ids))
+                    yield top_ids
+                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+        # 3. final decode
+        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
+        logging.info('no more text token, decode until met eos')
+        while True:
+            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
+            y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+                                                      cache=cache)
+            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
+            out_tokens.append(top_ids)
+            if top_ids >= self.speech_token_size:
+                if top_ids == self.speech_token_size:
+                    break
+                else:
+                    raise ValueError('should not get token {}'.format(top_ids))
+            # in stream mode, yield token one by one
+            yield top_ids
+            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
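
inference_bistream interleaves text and speech according to mix_ratio = [5, 15]: the prompt text and prompt speech tokens are fed in 5-text/15-speech blocks, after which the model alternates between receiving 5 new text tokens and emitting speech tokens until a fill token (speech_token_size + 2) asks for more text, with next_fill_index enforcing that rhythm; whatever text is left over is flushed together with the task id and decoded until EOS. A toy, model-free simulation of that schedule (token values are fabricated; in the real method the speech block length is whatever the LM emits before a fill token, nominally 15):

```python
# Toy simulation of the 5:15 text/speech interleaving used by inference_bistream.
MIX_TEXT, MIX_SPEECH = 5, 15   # mirrors mix_ratio = [5, 15]
FILL = 'FILL'                  # stands in for speech_token_size + 2

def interleave(text_chunks):
    text_cache, schedule = [], []
    for chunk in text_chunks:
        text_cache.extend(chunk)                             # new text tokens arrive from the generator
        while len(text_cache) >= MIX_TEXT:
            block, text_cache = text_cache[:MIX_TEXT], text_cache[MIX_TEXT:]
            schedule.append(('text', block))                 # feed 5 text tokens
            schedule.append(('speech', ['s'] * MIX_SPEECH))  # decode speech until the fill token
            schedule.append(('fill', FILL))                  # fill token: ask for more text
    schedule.append(('final_text', text_cache))              # leftovers + task id, decode until EOS
    return schedule

for step in interleave([list('收到好友从远方寄来的'), list('生日礼物,')]):
    print(step)
```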

cosyvoice/utils/common.py

Lines changed: 1 addition & 1 deletion
@@ -162,5 +162,5 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     # attention mask bias
     # NOTE(Mddct): torch.finfo jit issues
     # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
-    mask = (1.0 - mask) * torch.finfo(dtype).min
+    mask = (1.0 - mask) * -1.0e+10
     return mask
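
For reference, mask_to_bias turns a boolean keep-mask into an additive bias applied before softmax, so masked positions get a large negative score and essentially zero attention weight. A small self-contained re-implementation of that behavior with the new constant; note that the motivation for preferring -1.0e+10 over torch.finfo(dtype).min (for example, avoiding overflow when the bias is later summed or cast in half precision) is an assumption on my part, not stated in the commit:

```python
import torch

def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # illustrative re-implementation: True = keep, False = mask out
    mask = mask.to(dtype)
    mask = (1.0 - mask) * -1.0e+10
    return mask

scores = torch.tensor([[2.0, 1.0, 0.5]])
keep = torch.tensor([[True, True, False]])
bias = mask_to_bias(keep, scores.dtype)
print(torch.softmax(scores + bias, dim=-1))  # last position gets ~0 weight
```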

cosyvoice/utils/mask.py

Lines changed: 5 additions & 0 deletions
@@ -15,6 +15,7 @@
 # limitations under the License.

 import torch
+from cosyvoice.utils.file_utils import logging
 '''
 def subsequent_mask(
         size: int,
@@ -230,6 +231,10 @@ def add_optional_chunk_mask(xs: torch.Tensor,
         chunk_masks = masks & chunk_masks  # (B, L, L)
     else:
         chunk_masks = masks
+    assert chunk_masks.dtype == torch.bool
+    if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
+        logging.warning('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
+        chunk_masks[chunk_masks.sum(dim=-1)==0] = True
     return chunk_masks

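
The new check guards against attention rows that mask out every position: with a true -inf fill, the softmax of such a row is NaN, so the fix forces those rows to all-True and relies on downstream masking (as the warning says). A small demonstration of the degenerate case, independent of the repository:

```python
import torch

scores = torch.zeros(2, 3)
chunk_masks = torch.tensor([[True, True, False],
                            [False, False, False]])  # second row allows nothing

bad = scores.masked_fill(~chunk_masks, float('-inf')).softmax(dim=-1)
print(bad)   # second row is all nan

# the fix from the diff: set fully-masked rows to all True
fixed_masks = chunk_masks.clone()
fixed_masks[fixed_masks.sum(dim=-1) == 0] = True
ok = scores.masked_fill(~fixed_masks, float('-inf')).softmax(dim=-1)
print(ok)    # second row becomes a uniform distribution
```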
