diff --git a/export_meta.py b/export_meta.py index 920ca64..aa6a661 100644 --- a/export_meta.py +++ b/export_meta.py @@ -33,14 +33,14 @@ def export_forward( textnorm_query = self.embed(textnorm.to(speech.device)).unsqueeze(1) print(textnorm_query.shape, speech.shape) speech = torch.cat((textnorm_query, speech), dim=1) - speech_lengths += 1 + speech_lengths = speech_lengths + 1 event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat( speech.size(0), 1, 1 ) input_query = torch.cat((language_query, event_emo_query), dim=1) speech = torch.cat((input_query, speech), dim=1) - speech_lengths += 3 + speech_lengths = speech_lengths + 3 encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths) if isinstance(encoder_out, tuple):