@@ -317,13 +317,21 @@ def validate(self, valid_iter, moving_average=None):
                 tgt = batch.tgt
 
                 # F-prop through the model.
-                outputs, attns, enc_src, enc_tgt = valid_model(
-                    src, tgt, src_lengths,
-                    with_align=self.with_align)
+                if self.encode_tgt:
+                    outputs, attns, enc_src, enc_tgt = valid_model(
+                        src, tgt, src_lengths,
+                        with_align=self.with_align,
+                        encode_tgt=self.encode_tgt)
+                else:
+                    outputs, attns = valid_model(
+                        src, tgt, src_lengths,
+                        with_align=self.with_align)
+                    enc_src, enc_tgt = None, None
 
                 # Compute loss.
                 _, batch_stats = self.valid_loss(
-                    batch, outputs, attns, enc_src, enc_tgt)
+                    batch, outputs, attns,
+                    enc_src=enc_src, enc_tgt=enc_tgt)
 
                 # Update statistics.
                 stats.update(batch_stats)
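
What the first hunk implements: when encode_tgt is set, the model's forward also encodes the target and returns four values instead of two. A minimal runnable sketch of that return-arity contract, assuming an OpenNMT-py-style forward(); ToyModel and its tensor shapes are hypothetical stand-ins, not the actual model:

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def forward(self, src, tgt, lengths, with_align=False, encode_tgt=False):
        # Decoder outputs and attention, as in the two-value case.
        outputs = torch.zeros(tgt.size(0), 2, 8)
        attns = {"std": torch.zeros(tgt.size(0), 2, src.size(0))}
        if encode_tgt:
            # Additionally expose encoder states for src and tgt.
            enc_src = torch.zeros(src.size(0), 2, 8)
            enc_tgt = torch.zeros(tgt.size(0), 2, 8)
            return outputs, attns, enc_src, enc_tgt
        return outputs, attns

model = ToyModel()
src, tgt = torch.zeros(5, 2, 1), torch.zeros(6, 2, 1)
lengths = torch.tensor([5, 5])
encode_tgt = True
if encode_tgt:
    outputs, attns, enc_src, enc_tgt = model(src, tgt, lengths,
                                             encode_tgt=True)
else:
    outputs, attns = model(src, tgt, lengths)
    enc_src, enc_tgt = None, None  # keep a uniform four-name API downstream
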
@@ -366,9 +374,16 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
                 if self.accum_count == 1:
                     self.optim.zero_grad()
 
-                outputs, attns, enc_src, enc_tgt = self.model(
-                    src, tgt, src_lengths, bptt=bptt,
-                    with_align=self.with_align, encode_tgt=self.encode_tgt)
+                if self.encode_tgt:
+                    outputs, attns, enc_src, enc_tgt = self.model(
+                        src, tgt, src_lengths, bptt=bptt,
+                        with_align=self.with_align, encode_tgt=self.encode_tgt)
+                else:
+                    outputs, attns = self.model(
+                        src, tgt, src_lengths, bptt=bptt,
+                        with_align=self.with_align)
+                    enc_src, enc_tgt = None, None
+
                 bptt = True
 
                 # 3. Compute loss.
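
The bptt flag visible above follows the usual truncated-BPTT pattern: within one batch the first window is fed with bptt=False so the decoder (re)initializes its state, and bptt is then set to True so every later window continues from the cached state. A minimal sketch of that flow, under that assumption; the stateful ToyDecoder is hypothetical:

class ToyDecoder:
    def __init__(self):
        self.state = None

    def __call__(self, window, bptt=False):
        if not bptt or self.state is None:
            self.state = 0           # fresh state at the start of the batch
        self.state += len(window)    # consume this truncation window
        return self.state

dec = ToyDecoder()
tgt = list(range(10))
trunc_size, bptt = 4, False
for j in range(0, len(tgt), trunc_size):
    dec(tgt[j:j + trunc_size], bptt=bptt)
    bptt = True                      # all later windows reuse the state
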
@@ -377,8 +392,8 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
                     batch,
                     outputs,
                     attns,
-                    enc_src,
-                    enc_tgt,
+                    enc_src=enc_src,
+                    enc_tgt=enc_tgt,
                     normalization=normalization,
                     shard_size=self.shard_size,
                     trunc_start=j,
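
Passing enc_src/enc_tgt as keyword arguments (last hunk) lets the loss declare them as optional with None defaults, so the same call site works whether or not the target was encoded. A minimal sketch of that keyword-default pattern; toy_loss and its extra term are hypothetical, not the actual valid_loss/train_loss implementation:

import torch

def toy_loss(batch, outputs, attns, enc_src=None, enc_tgt=None,
             normalization=1.0):
    base = outputs.pow(2).mean()  # stand-in for the usual NLL term
    if enc_src is not None and enc_tgt is not None:
        # Extra term only when both encoder states were produced.
        base = base + (enc_src.mean() - enc_tgt.mean()).abs()
    return base / normalization

outputs = torch.randn(6, 2, 8)
loss_plain = toy_loss(batch=None, outputs=outputs, attns={})
loss_enc = toy_loss(batch=None, outputs=outputs, attns={},
                    enc_src=torch.randn(5, 2, 8),
                    enc_tgt=torch.randn(6, 2, 8))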