Commit 1847b26

fix: indextts long text infer
1 parent 45f8b01 commit 1847b26

1 file changed: modules/repos_static/index_tts/indextts/infer.py (6 additions, 5 deletions)
@@ -193,7 +193,7 @@ def bucket_sentences(self, sentences, bucket_max_size=4) -> List[List[Dict]]:
         outputs: List[Dict] = []
         for idx, sent in enumerate(sentences):
             outputs.append({"idx": idx, "sent": sent, "len": len(sent)})
-
+
         if len(outputs) > bucket_max_size:
             # split sentences into buckets by sentence length
             buckets: List[List[Dict]] = []
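
For context, bucket_sentences turns each sentence into the {"idx", "sent", "len"} dict shown above and then groups the dicts into length buckets before batching. A minimal sketch of that kind of length bucketing, assuming buckets are simply chunks of at most bucket_max_size similar-length sentences (bucket_by_length is a hypothetical stand-in, not the project's exact implementation):

# Hypothetical length-bucketing sketch; illustrative only.
from typing import Dict, List

def bucket_by_length(sentences: List[str], bucket_max_size: int = 4) -> List[List[Dict]]:
    items = [{"idx": i, "sent": s, "len": len(s)} for i, s in enumerate(sentences)]
    if len(items) <= bucket_max_size:
        return [items]
    # Sort by length so each bucket holds sentences of similar size,
    # which keeps right-padding waste low when a bucket is batched.
    items.sort(key=lambda d: d["len"])
    return [items[i:i + bucket_max_size] for i in range(0, len(items), bucket_max_size)]

# Six sentences -> two buckets of up to four similar-length items each.
for bucket in bucket_by_length(["a", "bb and more", "ccc", "dd", "e!", "ffff longer"]):
    print([d["idx"] for d in bucket])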
@@ -247,7 +247,9 @@ def pad_tokens_cat(self, tokens: List[torch.Tensor]) -> torch.Tensor:
             # For version 1.5 and above, right-pad directly with stop_text_token up to the max length
             # [1, N] -> [N,]
             tokens = [t.squeeze(0) for t in tokens]
-            return pad_sequence(tokens, batch_first=True, padding_value=self.cfg.gpt.stop_text_token, padding_side="right")
+            return pad_sequence(
+                tokens, batch_first=True, padding_value=self.cfg.gpt.stop_text_token
+            )
         max_len = max(t.size(1) for t in tokens)
         outputs = []
         for tensor in tokens:
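
The pad_sequence call above is the only non-whitespace change in this commit: padding_side is a keyword that only newer PyTorch releases accept, and right-side padding is already pad_sequence's default, so dropping the argument keeps the same output while avoiding a TypeError on older installs. A minimal sketch of the padding behavior, with stop_text_token as a hypothetical stand-in for self.cfg.gpt.stop_text_token:

# Right-padding a batch of [1, N] token tensors; stop_text_token is a stand-in pad id.
import torch
from torch.nn.utils.rnn import pad_sequence

stop_text_token = 0
tokens = [torch.tensor([[11, 12, 13]]), torch.tensor([[21, 22]])]  # [1, N] each
tokens = [t.squeeze(0) for t in tokens]                            # [1, N] -> [N]

# pad_sequence pads on the right by default, so padding_side="right" is redundant.
padded = pad_sequence(tokens, batch_first=True, padding_value=stop_text_token)
print(padded)  # tensor([[11, 12, 13], [21, 22,  0]])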
@@ -287,7 +289,7 @@ def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_to
         """
         if verbose:
             print(">> start fast inference...")
-
+
         self._set_gr_progress(0, "start fast inference...")
         if verbose:
             print(f"origin text:{text}")
@@ -365,8 +367,7 @@ def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_to
                 text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
                 print("text_token_syms is same as sentence tokens", text_token_syms == sent)
             temp_tokens.append(text_tokens)
-
-
+
         # Sequential processing of bucketing data
         all_batch_num = sum(len(s) for s in all_sentences)
         all_batch_codes = []
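
For orientation, the context lines above sit where infer_fast walks the buckets one by one and collects a padded batch per bucket. A hedged sketch of that sequential pass, using a hypothetical pad_bucket helper and tensor stand-ins rather than the project's actual sentence dicts and GPT call:

# Sequential pass over bucketed data; pad_bucket and the tensors are stand-ins.
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_bucket(token_list, pad_id=0):
    # Right-pad one bucket's [1, N] token tensors into a single [B, max_N] batch.
    return pad_sequence([t.squeeze(0) for t in token_list], batch_first=True, padding_value=pad_id)

all_sentences = [
    [torch.tensor([[1, 2, 3]]), torch.tensor([[4, 5]])],  # bucket 1
    [torch.tensor([[6]])],                                # bucket 2
]
all_batch_num = sum(len(s) for s in all_sentences)        # total items across buckets -> 3
all_batch_codes = []
for bucket in all_sentences:                              # sequential processing of bucketed data
    all_batch_codes.append(pad_bucket(bucket))            # stand-in for the per-batch GPT call
print(all_batch_num, [b.shape for b in all_batch_codes])  # 3 [torch.Size([2, 3]), torch.Size([1, 1])]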
