Skip to content

Commit 0e62489

Browse files
committed
simplify code
1 parent 2145b58 commit 0e62489

File tree

3 files changed

+62
-118
lines changed

3 files changed

+62
-118
lines changed

cosyvoice/cli/cosyvoice.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,6 @@ def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend
8989
start_time = time.time()
9090

9191
def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
92-
if self.__class__.__name__ == 'CosyVoice3' and '<|endofprompt|>' not in prompt_text + tts_text:
93-
logging.warning('<|endofprompt|> not found in CosyVoice3 inference, check your input text')
9492
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
9593
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
9694
if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):

cosyvoice/cli/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uui
117117
prompt_speech_token=llm_prompt_speech_token.to(self.device),
118118
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
119119
embedding=llm_embedding.to(self.device),
120-
uuid=uuid)
120+
uuid=uuid)
121121
for i in token_generator:
122122
if i in self.silent_tokens:
123123
cur_silent_token_num += 1
@@ -256,7 +256,7 @@ def __init__(self,
256256
self.fp16 = fp16
257257
# NOTE must matching training static_chunk_size
258258
self.token_hop_len = 25
259-
# NOTE increase token_hop_len incrementally to avoid duplicate inference
259+
# NOTE increase token_hop_len incrementally to avoid duplicate inference
260260
self.token_max_hop_len = 4 * self.token_hop_len
261261
self.stream_scale_factor = 2
262262
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
@@ -408,7 +408,7 @@ def __init__(self,
408408
self.fp16 = fp16
409409
# NOTE must matching training static_chunk_size
410410
self.token_hop_len = 25
411-
# NOTE increase token_hop_len incrementally to avoid duplicate inference
411+
# NOTE increase token_hop_len incrementally to avoid duplicate inference
412412
self.token_max_hop_len = 4 * self.token_hop_len
413413
self.stream_scale_factor = 2
414414
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'

cosyvoice/llm/llm.py

Lines changed: 59 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,9 @@ def sampling_ids(
154154
sampling: int,
155155
ignore_eos: bool = True,
156156
):
157-
num_trials, max_trials = 0, 100
158-
while True:
159-
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
160-
if (not ignore_eos) or (top_ids < self.speech_token_size):
161-
break
162-
num_trials += 1
163-
if num_trials > max_trials:
164-
raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
157+
if ignore_eos is True:
158+
weighted_scores[self.speech_token_size] = -float('inf')
159+
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
165160
return top_ids
166161

167162
@torch.inference_mode()
@@ -365,34 +360,48 @@ def forward(
365360
audio: (B, T, N) or (B, T)
366361
audio_lengths: (B,)
367362
"""
363+
# 1. encode text_token
368364
text_token = batch['text_token'].to(device)
369365
text_token_len = batch['text_token_len'].to(device)
366+
text_token_emb = self.llm.model.model.embed_tokens(text_token)
367+
368+
# 2. encode speech_token
370369
if 'speech_token' not in batch:
371370
speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
372371
else:
373372
speech_token = batch['speech_token'].to(device)
374373
speech_token_len = batch['speech_token_len'].to(device)
375-
376-
# 1. encode text_token
377-
text_token_emb = self.llm.model.model.embed_tokens(text_token)
378-
379-
# 3. sos and task_id
380-
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
381-
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
382-
383-
# 2. encode speech_token
384374
speech_token_emb = self.speech_embedding(speech_token)
385375

386-
# 3. prepare llm_input/target
387-
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
388-
speech_token, speech_token_emb, speech_token_len)
376+
# 3. sos and task_id
377+
if self.__class__.__name__ == 'CosyVoice3LM':
378+
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
379+
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
380+
elif self.__class__.__name__ == 'Qwen2LM':
381+
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
382+
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
383+
else:
384+
raise ValueError
385+
386+
# 4. prepare llm_input/target
387+
if self.__class__.__name__ == 'CosyVoice3LM':
388+
instruct_token = batch['instruct_token'].to(device)
389+
instruct_token_len = batch['instruct_token_len'].to(device)
390+
instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
391+
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
392+
speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
393+
elif self.__class__.__name__ == 'Qwen2LM':
394+
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
395+
speech_token, speech_token_emb, speech_token_len)
396+
else:
397+
raise ValueError
389398
lm_target = lm_target.to(device)
390399

391400
# 4. run lm forward
392401
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
393402
logits = self.llm_decoder(lm_output)
394403
loss = self.criterion_ce(logits, lm_target.to(device))
395-
acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
404+
acc = th_accuracy(logits.view(-1, self.llm_decoder.out_features), lm_target, ignore_label=IGNORE_ID)
396405
return {'loss': loss, 'acc': acc}
397406

398407
def forward_dpo(
@@ -464,16 +473,25 @@ def inference(
464473
device = text.device
465474
text = torch.concat([prompt_text, text], dim=1)
466475
text_len += prompt_text_len
467-
text = self.llm.model.model.embed_tokens(text)
476+
text_emb = self.llm.model.model.embed_tokens(text)
477+
if self.__class__.__name__ == 'CosyVoice3LM':
478+
# NOTE temporary hardcode, 151646 is <|endofprompt|> token
479+
assert 151646 in text, '<|endofprompt|> not detected in CosyVoice3 text or prompt_text, check your input!'
468480

469481
# 3. concat llm_input
470-
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
471-
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
482+
if self.__class__.__name__ == 'CosyVoice3LM':
483+
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
484+
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
485+
elif self.__class__.__name__ == 'Qwen2LM':
486+
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
487+
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
488+
else:
489+
raise ValueError
472490
if prompt_speech_token_len != 0:
473491
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
474492
else:
475-
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
476-
lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
493+
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text_emb.dtype).to(device)
494+
lm_input = torch.concat([sos_emb, text_emb, task_id_emb, prompt_speech_token_emb], dim=1)
477495

478496
# 4. cal min/max_length
479497
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -546,8 +564,14 @@ def inference_bistream(
546564

547565
device = prompt_text.device
548566
# 1. prepare input
549-
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
550-
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
567+
if self.__class__.__name__ == 'CosyVoice3LM':
568+
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
569+
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
570+
elif self.__class__.__name__ == 'Qwen2LM':
571+
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
572+
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
573+
else:
574+
raise ValueError
551575
if prompt_speech_token_len != 0:
552576
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
553577
else:
@@ -558,6 +582,12 @@ def inference_bistream(
558582
out_tokens = []
559583
cache = None
560584
# NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
585+
if self.__class__.__name__ == 'CosyVoice3LM':
586+
# NOTE temporary hardcode, 151646 is <|endofprompt|> token
587+
assert 151646 in prompt_text, '<|endofprompt|> not detected in CosyVoice3 prompt_text, check your input!'
588+
eop_index = prompt_text.flatten().tolist().index(151646)
589+
lm_input = torch.concat([lm_input, self.llm.model.model.embed_tokens(prompt_text[:, :eop_index + 1])], dim=1)
590+
prompt_text = prompt_text[:, eop_index + 1:]
561591
text_cache = self.llm.model.model.embed_tokens(prompt_text)
562592
next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]
563593
for this_text in text:
@@ -673,88 +703,4 @@ def __init__(
673703
self.stop_token_ids = [speech_token_size + i for i in range(200)]
674704
self.vllm_output_queue = {}
675705
if online_feature is True:
676-
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
677-
678-
def forward(
679-
self,
680-
batch: dict,
681-
device: torch.device,
682-
) -> Dict[str, Optional[torch.Tensor]]:
683-
"""
684-
Args:
685-
text: (B, L, D)
686-
text_lengths: (B,)
687-
audio: (B, T, N) or (B, T)
688-
audio_lengths: (B,)
689-
"""
690-
text_token = batch['text_token'].to(device)
691-
text_token_len = batch['text_token_len'].to(device)
692-
if 'speech_token' not in batch:
693-
speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
694-
else:
695-
speech_token = batch['speech_token'].to(device)
696-
speech_token_len = batch['speech_token_len'].to(device)
697-
698-
# NOTE should append instruct_token to sequence, not implemented yet
699-
instruct_token = batch['instruct_token'].to(device)
700-
instruct_token_len = batch['instruct_token_len'].to(device)
701-
702-
# 1. encode text_token
703-
text_token_emb = self.llm.model.model.embed_tokens(text_token)
704-
instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
705-
706-
# 3. sos and task_id
707-
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
708-
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
709-
710-
# 2. encode speech_token
711-
speech_token_emb = self.speech_embedding(speech_token)
712-
713-
# 3. prepare llm_input/target
714-
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
715-
speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
716-
lm_target = lm_target.to(device)
717-
718-
# 4. run lm forward
719-
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
720-
logits = self.llm_decoder(lm_output)
721-
loss = self.criterion_ce(logits, lm_target.to(device))
722-
acc = th_accuracy(logits.view(-1, self.speech_token_size + 200), lm_target, ignore_label=IGNORE_ID)
723-
return {'loss': loss, 'acc': acc}
724-
725-
@torch.inference_mode()
726-
def inference(
727-
self,
728-
text: torch.Tensor,
729-
text_len: torch.Tensor,
730-
prompt_text: torch.Tensor,
731-
prompt_text_len: torch.Tensor,
732-
prompt_speech_token: torch.Tensor,
733-
prompt_speech_token_len: torch.Tensor,
734-
embedding: torch.Tensor,
735-
sampling: int = 25,
736-
max_token_text_ratio: float = 20,
737-
min_token_text_ratio: float = 2,
738-
uuid: str = '',
739-
) -> Generator[torch.Tensor, None, None]:
740-
device = text.device
741-
text = torch.concat([prompt_text, text], dim=1)
742-
text_len += prompt_text_len
743-
text = self.llm.model.model.embed_tokens(text)
744-
745-
# 3. concat llm_input
746-
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
747-
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
748-
if prompt_speech_token_len != 0:
749-
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
750-
else:
751-
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
752-
lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
753-
754-
# 4. cal min/max_length
755-
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
756-
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
757-
758-
# 5. step by step decode
759-
for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
760-
yield token
706+
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))

0 commit comments

Comments
 (0)