
Commit 49761d2

Merge pull request #924 from FunAudioLLM/dev/lyuxiang.lx
add llm bistream
2 parents 41c5e8c + 07e4775 commit 49761d2

File tree: 9 files changed, +162 −38 lines changed

README.md

Lines changed: 10 additions & 0 deletions
@@ -143,6 +143,16 @@ for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒
 # instruct usage
 for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
     torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+# bistream usage: you can use a generator as input, which is useful when the input text comes from a text llm
+# NOTE you should still have some basic sentence-split logic, because the llm cannot handle arbitrary sentence lengths
+def text_generator():
+    yield '收到好友从远方寄来的生日礼物,'
+    yield '那份意外的惊喜与深深的祝福'
+    yield '让我心中充满了甜蜜的快乐,'
+    yield '笑容如花儿般绽放。'
+for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
+    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
 ```
 
 **CosyVoice Usage**
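
The NOTE in the new README lines deserves emphasis: the generator should yield sentence-sized chunks, not raw LLM deltas. A minimal sketch of that split logic, assuming an iterable of text deltas from some streaming chat API; `llm_text_generator` and `SENTENCE_ENDINGS` are illustrative names, not part of this commit:

```python
from typing import Generator, Iterable

SENTENCE_ENDINGS = '。!?!?'  # assumed set of sentence-final marks

def llm_text_generator(deltas: Iterable[str]) -> Generator[str, None, None]:
    buf = ''
    for piece in deltas:  # e.g. text deltas from a streaming chat API
        buf += piece
        # flush a complete sentence whenever a sentence-final mark appears
        while True:
            idx = next((i for i, ch in enumerate(buf) if ch in SENTENCE_ENDINGS), -1)
            if idx == -1:
                break
            yield buf[:idx + 1]
            buf = buf[idx + 1:]
    if buf:  # flush the trailing fragment when the stream ends
        yield buf
```

Such a generator can be passed wherever `text_generator()` appears above, e.g. `cosyvoice.inference_zero_shot(llm_text_generator(deltas), ...)`.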

cosyvoice/cli/cosyvoice.py

Lines changed: 2 additions & 1 deletion
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 import time
+from typing import Generator
 from tqdm import tqdm
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
@@ -76,7 +77,7 @@ def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend
     def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
         prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            if len(i) < 0.5 * len(prompt_text):
+            if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
                 logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
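
Why the added `isinstance` guard is needed: when `tts_text` is streamed in, `text_normalize` returns the generator untouched, and a generator has no `len()`, so the old length comparison would raise `TypeError`. A self-contained illustration (standard library only):

```python
from typing import Generator

def gen():
    yield '收到好友从远方寄来的生日礼物,'

g = gen()
assert isinstance(g, Generator)   # the same check the diff adds
try:
    len(g)                        # what the old comparison would attempt
except TypeError as e:
    print('len() on a generator raises:', e)
```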

cosyvoice/cli/frontend.py

Lines changed: 20 additions & 4 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
+from typing import Generator
 import json
 import onnxruntime
 import torch
@@ -31,6 +32,7 @@
     from tn.chinese.normalizer import Normalizer as ZhNormalizer
     from tn.english.normalizer import Normalizer as EnNormalizer
     use_ttsfrd = False
+from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
 
 
@@ -71,10 +73,21 @@ def __init__(self,
         self.inflect_parser = inflect.engine()
 
     def _extract_text_token(self, text):
-        text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
-        text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
-        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
-        return text_token, text_token_len
+        if isinstance(text, Generator):
+            logging.info('get tts_text generator, will return _extract_text_token_generator!')
+            # NOTE add a dummy text_token_len for compatibility
+            return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
+        else:
+            text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+            text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+            text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+            return text_token, text_token_len
+
+    def _extract_text_token_generator(self, text_generator):
+        for text in text_generator:
+            text_token, _ = self._extract_text_token(text)
+            for i in range(text_token.shape[1]):
+                yield text_token[:, i: i + 1]
 
     def _extract_speech_token(self, speech):
         assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
@@ -106,6 +119,9 @@ def _extract_speech_feat(self, speech):
         return speech_feat, speech_feat_len
 
     def text_normalize(self, text, split=True, text_frontend=True):
+        if isinstance(text, Generator):
+            logging.info('get tts_text generator, will skip text_normalize!')
+            return [text]
         if text_frontend is False:
             return [text] if split is True else text
         text = text.strip()
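
The new `_extract_text_token_generator` re-yields each tokenized chunk one token at a time, so downstream code sees a uniform stream of `(1, 1)` tensors. A standalone sketch of that slicing pattern (`stream_token_columns` and the literal token ids are illustrative):

```python
import torch

def stream_token_columns(token_rows):
    # mirror of the pattern above: re-yield each (1, T) row as T (1, 1) tensors
    for text_token in token_rows:
        for i in range(text_token.shape[1]):
            yield text_token[:, i: i + 1]

chunks = [torch.tensor([[101, 102, 103]]), torch.tensor([[104, 105]])]
print([t.item() for t in stream_token_columns(chunks)])  # [101, 102, 103, 104, 105]
```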

cosyvoice/cli/model.py

Lines changed: 19 additions & 8 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from typing import Generator
 import torch
 import numpy as np
 import threading
@@ -99,14 +100,24 @@ def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
-            for i in self.llm.inference(text=text.to(self.device),
-                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_text=prompt_text.to(self.device),
-                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                        embedding=llm_embedding.to(self.device)):
-                self.tts_speech_token_dict[uuid].append(i)
+            if isinstance(text, Generator):
+                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
+                for i in self.llm.inference_bistream(text=text,
+                                                     prompt_text=prompt_text.to(self.device),
+                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                     embedding=llm_embedding.to(self.device)):
+                    self.tts_speech_token_dict[uuid].append(i)
+            else:
+                for i in self.llm.inference(text=text.to(self.device),
+                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                            prompt_text=prompt_text.to(self.device),
+                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                            embedding=llm_embedding.to(self.device)):
+                    self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
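
`llm_job` runs in a background thread and appends speech tokens into `tts_speech_token_dict[uuid]`, which the synthesis loop drains concurrently; the branch above only swaps the producer, not that contract. A minimal sketch of the producer/consumer shape, with `fake_llm_job` standing in for `llm.inference(_bistream)` (all names here are illustrative, not the library's API):

```python
import threading
import time
import uuid

tts_speech_token_dict, llm_end_dict = {}, {}

def fake_llm_job(this_uuid):
    for token in range(5):                       # stand-in for tokens yielded by the llm
        tts_speech_token_dict[this_uuid].append(token)
        time.sleep(0.01)
    llm_end_dict[this_uuid] = True               # signal end of stream, as llm_job does

this_uuid = str(uuid.uuid1())
tts_speech_token_dict[this_uuid], llm_end_dict[this_uuid] = [], False
threading.Thread(target=fake_llm_job, args=(this_uuid,)).start()
while not llm_end_dict[this_uuid] or tts_speech_token_dict[this_uuid]:
    if tts_speech_token_dict[this_uuid]:
        print('consume token', tts_speech_token_dict[this_uuid].pop(0))
    else:
        time.sleep(0.005)                        # poll, as the streaming loop does
```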

cosyvoice/llm/llm.py

Lines changed: 107 additions & 23 deletions
@@ -20,6 +20,7 @@
 from cosyvoice.utils.common import IGNORE_ID
 from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
 from cosyvoice.utils.common import th_accuracy
+from cosyvoice.utils.file_utils import logging
 
 
 class TransformerLM(torch.nn.Module):
@@ -144,10 +145,14 @@ def sampling_ids(
             sampling: int,
             ignore_eos: bool = True,
     ):
+        num_trials, max_trials = 0, 100
         while True:
             top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
             if (not ignore_eos) or (self.speech_token_size not in top_ids):
                 break
+            num_trials += 1
+            if num_trials > max_trials:
+                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
         return top_ids
 
     @torch.inference_mode()
@@ -239,7 +244,7 @@ def forward_one_step(self, xs, masks, cache=None):
         return xs, new_cache
 
 
-class Qwen2LM(torch.nn.Module):
+class Qwen2LM(TransformerLM):
     def __init__(
             self,
             llm_input_size: int,
@@ -249,8 +254,9 @@ def __init__(
             sampling: Callable,
             length_normalized_loss: bool = True,
             lsm_weight: float = 0.0,
+            mix_ratio: List[int] = [5, 15],
     ):
-        super().__init__()
+        torch.nn.Module.__init__(self)
         self.llm_input_size = llm_input_size
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
@@ -275,23 +281,7 @@ def __init__(
 
         # 4. sampling method
         self.sampling = sampling
-
-    def sampling_ids(
-            self,
-            weighted_scores: torch.Tensor,
-            decoded_tokens: List,
-            sampling: int,
-            ignore_eos: bool = True,
-    ):
-        num_trials, max_trials = 0, 100
-        while True:
-            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
-            if (not ignore_eos) or (self.speech_token_size not in top_ids):
-                break
-            num_trials += 1
-            if num_trials > max_trials:
-                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
-        return top_ids
+        self.mix_ratio = mix_ratio
 
     @torch.inference_mode()
     def inference(
@@ -312,17 +302,14 @@ def inference(
         text_len += prompt_text_len
         text = self.llm.model.model.embed_tokens(text)
 
-        # 2. encode embedding
-        embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
-
         # 3. concat llm_input
         sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -345,3 +332,100 @@ def inference(
             yield top_ids
             out_tokens.append(top_ids)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+    @torch.inference_mode()
+    def inference_bistream(
+            self,
+            text: Generator,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+    ) -> Generator[torch.Tensor, None, None]:
+
+        device = prompt_text.device
+        # 1. prepare input
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
+        lm_input = torch.concat([sos_eos_emb], dim=1)
+
+        # 2. iterate text
+        out_tokens = []
+        cache = None
+        # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
+        text_cache = self.llm.model.model.embed_tokens(prompt_text)
+        next_fill_index = -1
+        for this_text in text:
+            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
+            # prompt_speech_token_emb not empty, try append to lm_input
+            while prompt_speech_token_emb.size(1) != 0:
+                if text_cache.size(1) >= self.mix_ratio[0]:
+                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
+                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
+                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
+                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
+                else:
+                    logging.info('not enough text token to decode, wait for more')
+                    break
+            # no prompt_speech_token_emb remain, can decode some speech token
+            if prompt_speech_token_emb.size(1) == 0:
+                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
+                    logging.info('get fill token, need to append more text token')
+                    if text_cache.size(1) >= self.mix_ratio[0]:
+                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
+                        logging.info('append {} text token'.format(lm_input_text.size(1)))
+                        lm_input = torch.concat([lm_input, lm_input_text], dim=1)
+                        text_cache = text_cache[:, self.mix_ratio[0]:]
+                    else:
+                        logging.info('not enough text token to decode, wait for more')
+                        continue
+                while True:
+                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
+                    y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+                                                              cache=cache)
+                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
+                        top_ids = self.speech_token_size + 2
+                        next_fill_index += (self.mix_ratio[1] + 1)
+                    else:
+                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
+                        if top_ids == self.speech_token_size + 2:
+                            next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
+                            logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
+                    out_tokens.append(top_ids)
+                    if top_ids >= self.speech_token_size:
+                        if top_ids == self.speech_token_size + 2:
+                            break
+                        else:
+                            raise ValueError('should not get token {}'.format(top_ids))
+                    yield top_ids
+                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+        # 3. final decode
+        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
+        logging.info('no more text token, decode until met eos')
+        while True:
+            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
+            y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+                                                      cache=cache)
+            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
+            out_tokens.append(top_ids)
+            if top_ids >= self.speech_token_size:
+                if top_ids == self.speech_token_size:
+                    break
+                else:
+                    raise ValueError('should not get token {}'.format(top_ids))
+            # in stream mode, yield token one by one
+            yield top_ids
+            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
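
The heart of `inference_bistream` is the `mix_ratio = [5, 15]` schedule: every 5 buffered text tokens admit up to 15 speech tokens, with the fill token (`speech_token_size + 2`, forced via `next_fill_index`) marking where the model pauses for more text; the leftover text plus `task_id` then trigger a final decode until EOS. A rough pure-arithmetic sketch of that schedule (no model required; the block strings are visual stand-ins, not actual token values):

```python
mix_ratio = [5, 15]            # text tokens : speech tokens per block

def interleave_schedule(num_text_tokens):
    blocks = []
    while num_text_tokens >= mix_ratio[0]:
        # 5 text tokens, then up to 15 speech tokens, then a fill token
        blocks.append('T' * mix_ratio[0] + 'S' * mix_ratio[1] + 'F')
        num_text_tokens -= mix_ratio[0]
    # final decode: leftover text, task_id, then speech tokens until eos
    blocks.append('T' * num_text_tokens + '|task_id|S...|eos|')
    return blocks

for block in interleave_schedule(12):
    print(block)
# TTTTTSSSSSSSSSSSSSSSF
# TTTTTSSSSSSSSSSSSSSSF
# TT|task_id|S...|eos|
```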

cosyvoice/utils/common.py

Lines changed: 1 addition & 1 deletion
@@ -162,5 +162,5 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     # attention mask bias
     # NOTE(Mddct): torch.finfo jit issues
     # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
-    mask = (1.0 - mask) * torch.finfo(dtype).min
+    mask = (1.0 - mask) * -1.0e+10
     return mask
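
The replacement constant keeps masked positions effectively at minus infinity for softmax while sidestepping the `torch.finfo` jit issue the NOTE mentions. A quick check of the behavior (a minimal sketch, assuming float32 masks):

```python
import torch

mask = torch.tensor([[True, True, False]])
bias = (1.0 - mask.to(torch.float32)) * -1.0e+10
print(bias)                         # ~0 where kept, -1e10 where masked
print(torch.softmax(bias, dim=-1))  # masked slot gets ~0 probability
```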
examples/libritts/cosyvoice2/cosyvoice

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+../../../cosyvoice

examples/libritts/cosyvoice2/tools

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+../../../tools

tools/extract_speech_token.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 
 
 def single_job(utt):
-    audio, sample_rate = torchaudio.load(utt2wav[utt])
+    audio, sample_rate = torchaudio.load(utt2wav[utt], backend='soundfile')
     if sample_rate != 16000:
         audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
    if audio.shape[1] / 16000 > 30:
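
Pinning `backend='soundfile'` selects torchaudio's soundfile loader explicitly (the keyword exists in recent torchaudio releases with the backend dispatcher); the resample branch below it is unchanged. A hedged sanity check of that branch with a synthetic 44.1 kHz signal:

```python
import torch
import torchaudio

audio = torch.randn(1, 44100)       # stand-in for a loaded waveform
sample_rate = 44100
if sample_rate != 16000:
    audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
print(audio.shape)                  # torch.Size([1, 16000])
```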
