@@ -74,7 +74,7 @@ def update_core(self, batch):
7474 if spk_emb is not None :
7575 spk_id = None
7676
77- after_outs , before_outs , logits , ys , labels , olens , att_ws , olens_in = self .model (
77+ after_outs , before_outs , logits , ys , stop_labels , olens , att_ws , olens_in = self .model (
7878 text = batch ["text" ],
7979 text_lengths = batch ["text_lengths" ],
8080 speech = batch ["speech" ],
@@ -83,8 +83,13 @@ def update_core(self, batch):
8383 spk_emb = spk_emb )
8484
8585 # calculate taco2 loss
86- l1_loss , mse_loss , bce_loss = self .taco2_loss (after_outs , before_outs ,
87- logits , ys , labels , olens )
86+ l1_loss , mse_loss , bce_loss = self .taco2_loss (
87+ after_outs = after_outs ,
88+ before_outs = before_outs ,
89+ logits = logits ,
90+ ys = ys ,
91+ stop_labels = stop_labels ,
92+ olens = olens )
8893
8994 if self .loss_type == "L1+L2" :
9095 loss = l1_loss + mse_loss + bce_loss
@@ -164,7 +169,7 @@ def evaluate_core(self, batch):
164169 if spk_emb is not None :
165170 spk_id = None
166171
167- after_outs , before_outs , logits , ys , labels , olens , att_ws , olens_in = self .model (
172+ after_outs , before_outs , logits , ys , stop_labels , olens , att_ws , olens_in = self .model (
168173 text = batch ["text" ],
169174 text_lengths = batch ["text_lengths" ],
170175 speech = batch ["speech" ],
@@ -173,8 +178,13 @@ def evaluate_core(self, batch):
173178 spk_emb = spk_emb )
174179
175180 # calculate taco2 loss
176- l1_loss , mse_loss , bce_loss = self .taco2_loss (after_outs , before_outs ,
177- logits , ys , labels , olens )
181+ l1_loss , mse_loss , bce_loss = self .taco2_loss (
182+ after_outs = after_outs ,
183+ before_outs = before_outs ,
184+ logits = logits ,
185+ ys = ys ,
186+ stop_labels = stop_labels ,
187+ olens = olens )
178188
179189 if self .loss_type == "L1+L2" :
180190 loss = l1_loss + mse_loss + bce_loss
0 commit comments