Skip to content

Commit c76cb6d

Browse files
Alexei Baevskifacebook-github-bot
authored andcommitted
composite criterion should still use legacy criterion as it will break with subsequent diff
Summary: see title Reviewed By: myleott Differential Revision: D24393903 fbshipit-source-id: 4b972b8150c7228fb32977675c6c60b13d5194d0
1 parent de5c2cb commit c76cb6d

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

fairseq/criterions/composite_loss.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44
# LICENSE file in the root directory of this source tree.
55

66
from fairseq import utils
7-
from fairseq.criterions import FairseqCriterion, register_criterion
7+
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
88
from torch import nn
99

1010

1111
@register_criterion("composite_loss")
12-
class CompositeLoss(FairseqCriterion):
12+
class CompositeLoss(LegacyFairseqCriterion):
1313
"""This is a composite loss that, given a list of model outputs and a list of targets,
1414
computes an average of losses for each output-target pair"""
1515

16-
def __init__(self, task, underlying_criterion):
17-
super().__init__(task)
18-
self.underlying_criterion = underlying_criterion
16+
def __init__(self, args, task):
17+
super().__init__(args, task)
18+
self.underlying_criterion = args.underlying_criterion
1919

2020
@staticmethod
2121
def add_args(parser):
@@ -60,9 +60,9 @@ def get_targets(self, *unused):
6060
def decoder(self):
6161
return self.model.decoder
6262

63-
class _CompositeLoss(FairseqCriterion):
64-
def __init__(self, task, underlying_criterion):
65-
super().__init__(task)
63+
class _CompositeLoss(LegacyFairseqCriterion):
64+
def __init__(self, args, task, underlying_criterion):
65+
super().__init__(args, task)
6666
self.underlying_criterion = underlying_criterion
6767

6868
def forward(self, model, sample, reduce=True):
@@ -97,4 +97,4 @@ def aggregate_logging_outputs(logging_outputs):
9797
def reduce_metrics(logging_outputs) -> None:
9898
underlying_criterion.__class__.reduce_metrics(logging_outputs)
9999

100-
return _CompositeLoss(task, underlying_criterion)
100+
return _CompositeLoss(args, task, underlying_criterion)

0 commit comments

Comments
 (0)