@@ -193,7 +193,7 @@ def bucket_sentences(self, sentences, bucket_max_size=4) -> List[List[Dict]]:
         outputs: List[Dict] = []
         for idx, sent in enumerate(sentences):
             outputs.append({"idx": idx, "sent": sent, "len": len(sent)})
-
+
         if len(outputs) > bucket_max_size:
             # split sentences into buckets by sentence length
             buckets: List[List[Dict]] = []
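
For context, the bucketing step groups sentences of similar length so each batch pads less. Below is a minimal, self-contained sketch of that idea, assuming a simple sort-then-chunk rule; `bucket_by_length` is a hypothetical helper, and the repo's actual balancing logic may differ:

```python
from typing import Dict, List

def bucket_by_length(items: List[Dict], bucket_max_size: int = 4) -> List[List[Dict]]:
    # Sort by the precomputed "len" field so each bucket holds similarly
    # sized sentences, which reduces padding waste when batching.
    ordered = sorted(items, key=lambda item: item["len"])
    return [ordered[i:i + bucket_max_size]
            for i in range(0, len(ordered), bucket_max_size)]

sentences = ["hi", "a much longer sentence here", "medium one"]
items = [{"idx": i, "sent": s, "len": len(s)} for i, s in enumerate(sentences)]
print(bucket_by_length(items, bucket_max_size=2))
```
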
@@ -247,7 +247,9 @@ def pad_tokens_cat(self, tokens: List[torch.Tensor]) -> torch.Tensor:
             # For version 1.5+, right-pad directly with stop_text_token up to the max length
             # [1, N] -> [N,]
             tokens = [t.squeeze(0) for t in tokens]
-            return pad_sequence(tokens, batch_first=True, padding_value=self.cfg.gpt.stop_text_token, padding_side="right")
+            return pad_sequence(
+                tokens, batch_first=True, padding_value=self.cfg.gpt.stop_text_token
+            )
         max_len = max(t.size(1) for t in tokens)
         outputs = []
         for tensor in tokens:
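
This change drops the `padding_side="right"` keyword, which `torch.nn.utils.rnn.pad_sequence` only accepts in newer PyTorch releases; right padding is the default anyway, so behavior is unchanged while older PyTorch versions stop raising a `TypeError`. A runnable sketch of the padding step, with `STOP_TEXT_TOKEN` standing in for `self.cfg.gpt.stop_text_token`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

STOP_TEXT_TOKEN = 0  # stand-in for self.cfg.gpt.stop_text_token

tokens = [torch.tensor([[1, 2, 3]]), torch.tensor([[4, 5]])]  # each [1, N]
tokens = [t.squeeze(0) for t in tokens]                       # [1, N] -> [N]
padded = pad_sequence(tokens, batch_first=True, padding_value=STOP_TEXT_TOKEN)
print(padded)  # tensor([[1, 2, 3], [4, 5, 0]]) -- right-padded by default
```
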
@@ -287,7 +289,7 @@ def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_to
         """
         if verbose:
             print(">> start fast inference...")
-
+
         self._set_gr_progress(0, "start fast inference...")
         if verbose:
             print(f"origin text:{text}")
@@ -365,8 +367,7 @@ def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_to
                 text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
                 print("text_token_syms is same as sentence tokens", text_token_syms == sent)
             temp_tokens.append(text_tokens)
-
-
+
         # Sequential processing of bucketing data
         all_batch_num = sum(len(s) for s in all_sentences)
         all_batch_codes = []
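
For reference, the accounting above treats `all_sentences` as a list of buckets, each holding sentence dicts, so `all_batch_num` counts every sentence across all buckets (used for progress reporting). A small sketch of that shape, assuming buckets come from `bucket_sentences`:

```python
# all_sentences: one list per bucket, each holding sentence dicts
all_sentences = [[{"idx": 0}, {"idx": 1}], [{"idx": 2}]]

all_batch_num = sum(len(s) for s in all_sentences)  # total items across buckets
all_batch_codes = []                                # filled batch by batch
print(all_batch_num)  # 3
```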