def sampling_ids(
        self,
        weighted_scores: torch.Tensor,
        decoded_tokens: torch.Tensor,
        sampling: int,
        ignore_eos: bool = True,
):
    """Sample the next speech-token id from the score distribution.

    Args:
        weighted_scores: 1-D tensor of logits/scores over the token vocabulary.
            NOTE: mutated in place when ignore_eos is True (the eos entry is
            set to -inf so it can never be sampled).
        decoded_tokens: tokens decoded so far, forwarded to self.sampling
            (e.g. for repetition-aware sampling).
        sampling: sampling hyper-parameter forwarded to self.sampling
            (top-k size in this codebase).
        ignore_eos: when True, mask out the eos token (index
            self.speech_token_size) instead of retry-sampling.

    Returns:
        The sampled token id(s) as returned by self.sampling.
    """
    if ignore_eos:
        # Masking eos to -inf guarantees a non-eos sample in a single draw,
        # replacing the old retry-until-non-eos loop.
        weighted_scores[self.speech_token_size] = -float('inf')
    top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
    return top_ids
167162 @torch .inference_mode ()
def forward(
        self,
        batch: dict,
        device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
    """Training forward pass shared by Qwen2LM and CosyVoice3LM.

    Args:
        batch: dict with 'text_token'/'text_token_len' and either
            'speech_token'/'speech_token_len' or 'whisper_feat'/'whisper_feat_len'
            (speech tokens are then extracted online). For CosyVoice3LM the
            batch must additionally contain 'instruct_token'/'instruct_token_len'.
        device: device batch tensors are moved to.

    Returns:
        {'loss': cross-entropy loss, 'acc': token-level accuracy}

    Raises:
        ValueError: if called on an LM class other than Qwen2LM/CosyVoice3LM.
    """
    # 1. encode text_token
    text_token = batch['text_token'].to(device)
    text_token_len = batch['text_token_len'].to(device)
    text_token_emb = self.llm.model.model.embed_tokens(text_token)

    # 2. encode speech_token (extract online from whisper feats when absent)
    if 'speech_token' not in batch:
        speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
    else:
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
    speech_token_emb = self.speech_embedding(speech_token)

    # 3. sos and task_id: CosyVoice3 keeps these special tokens in the speech
    #    embedding table, Qwen2 in the dedicated llm_embedding table
    lm_class = self.__class__.__name__
    if lm_class == 'CosyVoice3LM':
        special_embedding = self.speech_embedding
    elif lm_class == 'Qwen2LM':
        special_embedding = self.llm_embedding
    else:
        raise ValueError('unsupported LM class {}, expected Qwen2LM or CosyVoice3LM'.format(lm_class))
    sos_emb = special_embedding.weight[self.sos].reshape(1, 1, -1)
    task_id_emb = special_embedding.weight[self.task_id].reshape(1, 1, -1)

    # 4. prepare llm_input/target (CosyVoice3 additionally feeds instruct tokens)
    if lm_class == 'CosyVoice3LM':
        instruct_token = batch['instruct_token'].to(device)
        instruct_token_len = batch['instruct_token_len'].to(device)
        instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
                                                                         speech_token, speech_token_emb, speech_token_len,
                                                                         instruct_token, instruct_token_emb, instruct_token_len)
    else:
        # lm_class already validated above, so this is the Qwen2LM path
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
                                                                         speech_token, speech_token_emb, speech_token_len)
    lm_target = lm_target.to(device)

    # 5. run lm forward
    lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
    logits = self.llm_decoder(lm_output)
    loss = self.criterion_ce(logits, lm_target.to(device))
    # out_features covers speech tokens + special tokens for either LM variant,
    # avoiding hard-coded vocabulary-size offsets
    acc = th_accuracy(logits.view(-1, self.llm_decoder.out_features), lm_target, ignore_label=IGNORE_ID)
    return {'loss': loss, 'acc': acc}
397406
398407 def forward_dpo (
@@ -464,16 +473,25 @@ def inference(
464473 device = text .device
465474 text = torch .concat ([prompt_text , text ], dim = 1 )
466475 text_len += prompt_text_len
467- text = self .llm .model .model .embed_tokens (text )
476+ text_emb = self .llm .model .model .embed_tokens (text )
477+ if self .__class__ .__name__ == 'CosyVoice3LM' :
478+ # NOTE temporary hardcode, 151646 is <|endofprompt|> token
479+ assert 151646 in text , '<|endofprompt|> not detected in CosyVoice3 text or prompt_text, check your input!'
468480
469481 # 3. concat llm_input
470- sos_emb = self .llm_embedding .weight [self .sos ].reshape (1 , 1 , - 1 )
471- task_id_emb = self .llm_embedding .weight [self .task_id ].reshape (1 , 1 , - 1 )
482+ if self .__class__ .__name__ == 'CosyVoice3LM' :
483+ sos_emb = self .speech_embedding .weight [self .sos ].reshape (1 , 1 , - 1 )
484+ task_id_emb = self .speech_embedding .weight [self .task_id ].reshape (1 , 1 , - 1 )
485+ elif self .__class__ .__name__ == 'Qwen2LM' :
486+ sos_emb = self .llm_embedding .weight [self .sos ].reshape (1 , 1 , - 1 )
487+ task_id_emb = self .llm_embedding .weight [self .task_id ].reshape (1 , 1 , - 1 )
488+ else :
489+ raise ValueError
472490 if prompt_speech_token_len != 0 :
473491 prompt_speech_token_emb = self .speech_embedding (prompt_speech_token )
474492 else :
475- prompt_speech_token_emb = torch .zeros (1 , 0 , self .llm_input_size , dtype = text .dtype ).to (device )
476- lm_input = torch .concat ([sos_emb , text , task_id_emb , prompt_speech_token_emb ], dim = 1 )
493+ prompt_speech_token_emb = torch .zeros (1 , 0 , self .llm_input_size , dtype = text_emb .dtype ).to (device )
494+ lm_input = torch .concat ([sos_emb , text_emb , task_id_emb , prompt_speech_token_emb ], dim = 1 )
477495
478496 # 4. cal min/max_length
479497 min_len = int ((text_len - prompt_text_len ) * min_token_text_ratio )
@@ -546,8 +564,14 @@ def inference_bistream(
546564
547565 device = prompt_text .device
548566 # 1. prepare input
549- sos_emb = self .llm_embedding .weight [self .sos ].reshape (1 , 1 , - 1 )
550- task_id_emb = self .llm_embedding .weight [self .task_id ].reshape (1 , 1 , - 1 )
567+ if self .__class__ .__name__ == 'CosyVoice3LM' :
568+ sos_emb = self .speech_embedding .weight [self .sos ].reshape (1 , 1 , - 1 )
569+ task_id_emb = self .speech_embedding .weight [self .task_id ].reshape (1 , 1 , - 1 )
570+ elif self .__class__ .__name__ == 'Qwen2LM' :
571+ sos_emb = self .llm_embedding .weight [self .sos ].reshape (1 , 1 , - 1 )
572+ task_id_emb = self .llm_embedding .weight [self .task_id ].reshape (1 , 1 , - 1 )
573+ else :
574+ raise ValueError
551575 if prompt_speech_token_len != 0 :
552576 prompt_speech_token_emb = self .speech_embedding (prompt_speech_token )
553577 else :
@@ -558,6 +582,12 @@ def inference_bistream(
558582 out_tokens = []
559583 cache = None
560584 # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
585+ if self .__class__ .__name__ == 'CosyVoice3LM' :
586+ # NOTE temporary hardcode, 151646 is <|endofprompt|> token
587+ assert 151646 in prompt_text , '<|endofprompt|> not detected in CosyVoice3 prompt_text, check your input!'
588+ eop_index = prompt_text .flatten ().tolist ().index (151646 )
589+ lm_input = torch .concat ([lm_input , self .llm .model .model .embed_tokens (prompt_text [:, :eop_index + 1 ])], dim = 1 )
590+ prompt_text = prompt_text [:, eop_index + 1 :]
561591 text_cache = self .llm .model .model .embed_tokens (prompt_text )
562592 next_fill_index = (int (prompt_speech_token .shape [1 ] / self .mix_ratio [1 ]) + 1 ) * self .mix_ratio [1 ] - prompt_speech_token .shape [1 ]
563593 for this_text in text :
@@ -673,88 +703,4 @@ def __init__(
673703 self .stop_token_ids = [speech_token_size + i for i in range (200 )]
674704 self .vllm_output_queue = {}
675705 if online_feature is True :
676- self .speech_token_extractor = SpeechTokenExtractor (model_path = os .path .join (onnx_path , 'speech_tokenizer_v3.batch.onnx' ))
677-
def forward(
        self,
        batch: dict,
        device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
    """CosyVoice3 training forward: cross-entropy loss + accuracy.

    Args:
        batch: dict with 'text_token'/'text_token_len',
            'instruct_token'/'instruct_token_len', and either
            'speech_token'/'speech_token_len' or 'whisper_feat'/'whisper_feat_len'
            (speech tokens are then extracted online).
        device: device batch tensors are moved to.

    Returns:
        {'loss': cross-entropy loss, 'acc': token-level accuracy}
    """
    # 1. load text/speech/instruct tokens
    text_token = batch['text_token'].to(device)
    text_token_len = batch['text_token_len'].to(device)
    if 'speech_token' not in batch:
        speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
    else:
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
    instruct_token = batch['instruct_token'].to(device)
    instruct_token_len = batch['instruct_token_len'].to(device)

    # 2. encode text/instruct tokens with the text LM embedding table
    text_token_emb = self.llm.model.model.embed_tokens(text_token)
    instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)

    # 3. sos and task_id live in the speech embedding table for CosyVoice3
    sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
    task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)

    # 4. encode speech_token
    speech_token_emb = self.speech_embedding(speech_token)

    # 5. prepare llm_input/target (instruct tokens are passed through to the
    #    sequence builder; NOTE(review): confirm prepare_lm_input_target
    #    actually consumes them as intended)
    lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
                                                                     speech_token, speech_token_emb, speech_token_len,
                                                                     instruct_token, instruct_token_emb, instruct_token_len)
    lm_target = lm_target.to(device)

    # 6. run lm forward
    lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
    logits = self.llm_decoder(lm_output)
    loss = self.criterion_ce(logits, lm_target.to(device))
    # use the decoder's out_features instead of the hard-coded
    # speech_token_size + 200, keeping accuracy consistent with the logit width
    acc = th_accuracy(logits.view(-1, self.llm_decoder.out_features), lm_target, ignore_label=IGNORE_ID)
    return {'loss': loss, 'acc': acc}
724-
@torch.inference_mode()
def inference(
        self,
        text: torch.Tensor,
        text_len: torch.Tensor,
        prompt_text: torch.Tensor,
        prompt_text_len: torch.Tensor,
        prompt_speech_token: torch.Tensor,
        prompt_speech_token_len: torch.Tensor,
        embedding: torch.Tensor,
        sampling: int = 25,
        max_token_text_ratio: float = 20,
        min_token_text_ratio: float = 2,
        uuid: str = '',
) -> Generator[torch.Tensor, None, None]:
    """Autoregressively generate speech tokens for CosyVoice3.

    Args:
        text / text_len: target text token ids and length.
        prompt_text / prompt_text_len: prompt text token ids and length.
        prompt_speech_token / prompt_speech_token_len: prompt speech tokens
            (length 0 means no speech prompt).
        embedding: speaker embedding (unused here; kept for interface parity).
        sampling: top-k sampling parameter.
        max_token_text_ratio / min_token_text_ratio: bounds on generated
            length relative to the non-prompt text length.
        uuid: request identifier forwarded to the decode wrapper.

    Yields:
        Generated speech token(s), one step at a time.
    """
    device = text.device
    # 1. concat prompt and target text, then embed; keep the embedded tensor
    #    in a separate name so the token-id tensor is not shadowed
    text = torch.concat([prompt_text, text], dim=1)
    text_len += prompt_text_len
    text_emb = self.llm.model.model.embed_tokens(text)

    # 2. sos/task_id come from the speech embedding table in CosyVoice3
    sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
    task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)

    # 3. concat llm_input
    if prompt_speech_token_len != 0:
        prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
    else:
        prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text_emb.dtype).to(device)
    lm_input = torch.concat([sos_emb, text_emb, task_id_emb, prompt_speech_token_emb], dim=1)

    # 4. cal min/max_length from the non-prompt portion of the text only
    min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
    max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

    # 5. step by step decode
    for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
        yield token
706+ self .speech_token_extractor = SpeechTokenExtractor (model_path = os .path .join (onnx_path , 'speech_tokenizer_v3.batch.onnx' ))
0 commit comments