diff --git a/model.py b/model.py index 1dd42180..d78e0a65 100644 --- a/model.py +++ b/model.py @@ -172,10 +172,16 @@ def forward( lm_logits = self.lm_head(sequence_output) loss = None + z_loss = 0.0 if labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-100) loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + # Z-loss + if z_loss != 0.0: + log_z = torch.log(lm_logits) + loss += z_loss * torch.pow(log_z, 2.0) if not return_dict: output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs diff --git a/requirements.txt b/requirements.txt index 9ef5abc3..85d39a65 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,5 @@ sentence-transformers==2.2.2 transformers==4.21.1 nltk==3.6.6 evaluate==0.4.0 -rouge==1.0.1 rouge_score==0.1.2 +rich==13.3.2 \ No newline at end of file