|
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 |
|
6 | 6 | from fairseq import utils |
7 | | -from fairseq.criterions import FairseqCriterion, register_criterion |
| 7 | +from fairseq.criterions import LegacyFairseqCriterion, register_criterion |
8 | 8 | from torch import nn |
9 | 9 |
|
10 | 10 |
|
11 | 11 | @register_criterion("composite_loss") |
12 | | -class CompositeLoss(FairseqCriterion): |
| 12 | +class CompositeLoss(LegacyFairseqCriterion): |
13 | 13 | """This is a composite loss that, given a list of model outputs and a list of targets, |
14 | 14 | computes an average of losses for each output-target pair""" |
15 | 15 |
|
16 | | - def __init__(self, task, underlying_criterion): |
17 | | - super().__init__(task) |
18 | | - self.underlying_criterion = underlying_criterion |
| 16 | + def __init__(self, args, task): |
| 17 | + super().__init__(args, task) |
| 18 | + self.underlying_criterion = args.underlying_criterion |
19 | 19 |
|
20 | 20 | @staticmethod |
21 | 21 | def add_args(parser): |
@@ -60,9 +60,9 @@ def get_targets(self, *unused): |
60 | 60 | def decoder(self): |
61 | 61 | return self.model.decoder |
62 | 62 |
|
63 | | - class _CompositeLoss(FairseqCriterion): |
64 | | - def __init__(self, task, underlying_criterion): |
65 | | - super().__init__(task) |
| 63 | + class _CompositeLoss(LegacyFairseqCriterion): |
| 64 | + def __init__(self, args, task, underlying_criterion): |
| 65 | + super().__init__(args, task) |
66 | 66 | self.underlying_criterion = underlying_criterion |
67 | 67 |
|
68 | 68 | def forward(self, model, sample, reduce=True): |
@@ -97,4 +97,4 @@ def aggregate_logging_outputs(logging_outputs): |
97 | 97 | def reduce_metrics(logging_outputs) -> None: |
98 | 98 | underlying_criterion.__class__.reduce_metrics(logging_outputs) |
99 | 99 |
|
100 | | - return _CompositeLoss(task, underlying_criterion) |
| 100 | + return _CompositeLoss(args, task, underlying_criterion) |
0 commit comments