@@ -25,6 +25,11 @@ def build_loss_compute(model, tgt_field, opt, train=True):
2525
2626 padding_idx = tgt_field .vocab .stoi [tgt_field .pad_token ]
2727 unk_idx = tgt_field .vocab .stoi [tgt_field .unk_token ]
28+
29+ if opt .lambda_coverage != 0 :
30+ assert opt .coverage_attn , "--coverage_attn needs to be set in " \
31+ "order to use --lambda_coverage != 0"
32+
2833 if opt .copy_attn :
2934 criterion = onmt .modules .CopyGeneratorLoss (
3035 len (tgt_field .vocab ), opt .copy_attn_force ,
@@ -47,10 +52,12 @@ def build_loss_compute(model, tgt_field, opt, train=True):
4752 loss_gen = model .generator [0 ] if use_raw_logits else model .generator
4853 if opt .copy_attn :
4954 compute = onmt .modules .CopyGeneratorLossCompute (
50- criterion , loss_gen , tgt_field .vocab , opt .copy_loss_by_seqlength
55+ criterion , loss_gen , tgt_field .vocab , opt .copy_loss_by_seqlength ,
56+ lambda_coverage = opt .lambda_coverage
5157 )
5258 else :
53- compute = NMTLossCompute (criterion , loss_gen )
59+ compute = NMTLossCompute (
60+ criterion , loss_gen , lambda_coverage = opt .lambda_coverage )
5461 compute .to (device )
5562
5663 return compute
@@ -218,26 +225,53 @@ class NMTLossCompute(LossComputeBase):
218225 Standard NMT Loss Computation.
219226 """
220227
221- def __init__ (self , criterion , generator , normalization = "sents" ):
228+ def __init__ (self , criterion , generator , normalization = "sents" ,
229+ lambda_coverage = 0.0 ):
222230 super (NMTLossCompute , self ).__init__ (criterion , generator )
231+ self .lambda_coverage = lambda_coverage
223232
224233 def _make_shard_state (self , batch , output , range_ , attns = None ):
225- return {
234+ shard_state = {
226235 "output" : output ,
227236 "target" : batch .tgt [range_ [0 ] + 1 : range_ [1 ], :, 0 ],
228237 }
238+ if self .lambda_coverage != 0.0 :
239+ coverage = attns .get ("coverage" , None )
240+ std = attns .get ("std" , None )
241+ assert attns is not None
242+ assert std is not None , "lambda_coverage != 0.0 requires " \
243+ "attention mechanism"
244+ assert coverage is not None , "lambda_coverage != 0.0 requires " \
245+ "coverage attention"
246+
247+ shard_state .update ({
248+ "std_attn" : attns .get ("std" ),
249+ "coverage_attn" : coverage
250+ })
251+ return shard_state
252+
253+ def _compute_loss (self , batch , output , target , std_attn = None ,
254+ coverage_attn = None ):
229255
230- def _compute_loss (self , batch , output , target ):
231256 bottled_output = self ._bottle (output )
232257
233258 scores = self .generator (bottled_output )
234259 gtruth = target .view (- 1 )
235260
236261 loss = self .criterion (scores , gtruth )
262+ if self .lambda_coverage != 0.0 :
263+ coverage_loss = self ._compute_coverage_loss (
264+ std_attn = std_attn , coverage_attn = coverage_attn )
265+ loss += coverage_loss
237266 stats = self ._stats (loss .clone (), scores , gtruth )
238267
239268 return loss , stats
240269
270+ def _compute_coverage_loss (self , std_attn , coverage_attn ):
271+ covloss = torch .min (std_attn , coverage_attn ).sum (2 ).view (- 1 )
272+ covloss *= self .lambda_coverage
273+ return covloss
274+
241275
242276def filter_shard_state (state , shard_size = None ):
243277 for k , v in state .items ():
0 commit comments