We read every piece of feedback and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 88f6a8e + d3b1a8e · commit 737d101 · Copy full SHA for 737d101
cosyvoice/llm/llm.py
@@ -280,10 +280,14 @@ def sampling_ids(
280
sampling: int,
281
ignore_eos: bool = True,
282
):
283
+ num_trials, max_trials = 0, 100
284
while True:
285
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
286
if (not ignore_eos) or (self.speech_token_size not in top_ids):
287
break
288
+ num_trials += 1
289
+ if num_trials > max_trials:
290
+ raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
291
return top_ids
292
293
@torch.inference_mode()
0 commit comments