 from cosyvoice.utils.common import IGNORE_ID
 from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
 from cosyvoice.utils.common import th_accuracy
+from cosyvoice.utils.file_utils import logging
 
 
 class TransformerLM(torch.nn.Module):
@@ -144,10 +145,14 @@ def sampling_ids(
             sampling: int,
             ignore_eos: bool = True,
     ):
+        num_trials, max_trials = 0, 100
         while True:
             top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
             if (not ignore_eos) or (self.speech_token_size not in top_ids):
                 break
+            num_trials += 1
+            if num_trials > max_trials:
+                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
         return top_ids
 
     @torch.inference_mode()
@@ -239,7 +244,7 @@ def forward_one_step(self, xs, masks, cache=None):
         return xs, new_cache
 
 
-class Qwen2LM(torch.nn.Module):
+class Qwen2LM(TransformerLM):
     def __init__(
             self,
             llm_input_size: int,
@@ -249,8 +254,9 @@ def __init__(
             sampling: Callable,
             length_normalized_loss: bool = True,
             lsm_weight: float = 0.0,
+            mix_ratio: List[int] = [5, 15],
     ):
-        super().__init__()
+        torch.nn.Module.__init__(self)
         self.llm_input_size = llm_input_size
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
@@ -275,23 +281,7 @@ def __init__(
 
         # 4. sampling method
         self.sampling = sampling
-
-    def sampling_ids(
-            self,
-            weighted_scores: torch.Tensor,
-            decoded_tokens: List,
-            sampling: int,
-            ignore_eos: bool = True,
-    ):
-        num_trials, max_trials = 0, 100
-        while True:
-            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
-            if (not ignore_eos) or (self.speech_token_size not in top_ids):
-                break
-            num_trials += 1
-            if num_trials > max_trials:
-                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
-        return top_ids
+        self.mix_ratio = mix_ratio
 
     @torch.inference_mode()
     def inference(
@@ -312,17 +302,14 @@ def inference(
         text_len += prompt_text_len
         text = self.llm.model.model.embed_tokens(text)
 
-        # 2. encode embedding
-        embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
-
         # 3. concat llm_input
         sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -345,3 +332,103 @@ def inference(
             yield top_ids
             out_tokens.append(top_ids)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+    @torch.inference_mode()
+    def inference_bistream(
+            self,
+            text: Generator,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+    ) -> Generator[torch.Tensor, None, None]:
+
+        device = prompt_text.device
+        # 1. prepare input
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
+        lm_input = torch.concat([sos_eos_emb], dim=1)
+
+        # 2. iterate text
+        out_tokens = []
+        cache = None
+        # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
+        text_cache = self.llm.model.model.embed_tokens(prompt_text)
+        next_fill_index = -1
+        for this_text in text:
+            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
+            # prompt_speech_token_emb not empty, try append to lm_input
+            while prompt_speech_token_emb.size(1) != 0:
+                if text_cache.size(1) >= self.mix_ratio[0]:
+                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
+                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
+                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
+                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
+                else:
+                    logging.info('not enough text token to decode, wait for more')
+                    break
+            # no prompt_speech_token_emb remain, can decode some speech token
+            if prompt_speech_token_emb.size(1) == 0:
+                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
+                    logging.info('get fill token, need to append more text token')
+                    if text_cache.size(1) >= self.mix_ratio[0]:
+                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
+                        logging.info('append {} text token'.format(lm_input_text.size(1)))
+                        if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
+                            lm_input = lm_input_text
+                        else:
+                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
+                        text_cache = text_cache[:, self.mix_ratio[0]:]
+                    else:
+                        logging.info('not enough text token to decode, wait for more')
+                        continue
+                while True:
+                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
+                    y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+                                                              cache=cache)
+                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
+                        top_ids = self.speech_token_size + 2
+                        next_fill_index += (self.mix_ratio[1] + 1)
+                    else:
+                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
+                    if top_ids == self.speech_token_size + 2:
+                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
+                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
+                    out_tokens.append(top_ids)
+                    if top_ids >= self.speech_token_size:
+                        if top_ids == self.speech_token_size + 2:
+                            break
+                        else:
+                            raise ValueError('should not get token {}'.format(top_ids))
+                    yield top_ids
+                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+        # 3. final decode
+        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
+        logging.info('no more text token, decode until met eos')
+        while True:
+            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
+            y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+                                                      cache=cache)
+            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
+            out_tokens.append(top_ids)
+            if top_ids >= self.speech_token_size:
+                if top_ids == self.speech_token_size:
+                    break
+                else:
+                    raise ValueError('should not get token {}'.format(top_ids))
+            # in stream mode, yield token one by one
+            yield top_ids
+            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
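
For reference, a minimal sketch of how the new inference_bistream generator could be driven. Only the inference_bistream signature comes from this diff; the chunk source and tokenizer below (text_chunks, tokenizer.encode) are illustrative assumptions, and embedding is passed as an empty tensor since the new method does not reference it.

    import torch

    def text_chunks(pieces, tokenizer):
        # yield each incoming text piece as a (1, T) tensor of text token ids
        # (tokenizer.encode is assumed to return a list of ids)
        for piece in pieces:
            yield torch.tensor([tokenizer.encode(piece)], dtype=torch.int64)

    # llm is a loaded Qwen2LM; prompt_text / prompt_speech_token come from the usual frontend
    for speech_token_id in llm.inference_bistream(
            text=text_chunks(["hello ", "world"], tokenizer),
            prompt_text=prompt_text,
            prompt_text_len=torch.tensor([prompt_text.size(1)], dtype=torch.int32),
            prompt_speech_token=prompt_speech_token,
            prompt_speech_token_len=torch.tensor([prompt_speech_token.size(1)], dtype=torch.int32),
            embedding=torch.zeros(1, 0, 0)):
        # each yielded speech token id can be forwarded to the flow/vocoder stage as it arrives
        print(speech_token_id)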