diff --git a/model.py b/model.py index a89defd..d3bfe04 100644 --- a/model.py +++ b/model.py @@ -896,7 +896,7 @@ def inference( align = ctc_forced_align( logits_speech.unsqueeze(0).float(), torch.Tensor(token_int[4:]).unsqueeze(0).long().to(logits_speech.device), - (encoder_out_lens-4).long(), + (encoder_out_lens-4).long()[i], torch.tensor(len(token_int)-4).unsqueeze(0).long().to(logits_speech.device), ignore_id=self.ignore_id, )