@@ -1881,6 +1881,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        loss_reduction: str = "mean",
     ) -> Union[Tuple, ParlerTTSCausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
@@ -1925,7 +1926,7 @@ def forward(
             # since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels
             logits = lm_logits[:, :, -labels.shape[1] :]

-            loss_fct = CrossEntropyLoss()
+            loss_fct = CrossEntropyLoss(reduction=loss_reduction)
             loss = torch.zeros([], device=self.device)

             per_codebook_losses = []
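
For reference, a minimal sketch of what the `reduction` argument changes in `CrossEntropyLoss`; the shapes below are placeholders for illustration, not taken from this model. With `"mean"` the loss is averaged over the scored tokens, while `"sum"` returns the unnormalized total, which lets the caller pick its own normalization (for example across gradient-accumulation steps).

import torch
from torch.nn import CrossEntropyLoss

# Placeholder shapes: 2 sequences of 5 tokens over a 10-symbol vocabulary.
logits = torch.randn(2, 5, 10)
labels = torch.randint(0, 10, (2, 5))

mean_loss = CrossEntropyLoss(reduction="mean")(logits.view(-1, 10), labels.view(-1))
sum_loss = CrossEntropyLoss(reduction="sum")(logits.view(-1, 10), labels.view(-1))

# With no ignored targets, the sum equals the mean times the number of scored tokens.
assert torch.allclose(sum_loss, mean_loss * labels.numel())
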
@@ -2713,6 +2714,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        loss_reduction: str = "mean",
         **kwargs,
     ) -> Union[Tuple, ParlerTTSSeq2SeqLMOutput]:
         r"""
@@ -2857,6 +2859,7 @@ def forward(
             return_dict=return_dict,
             labels=labels,
             cache_position=cache_position,
+            loss_reduction=loss_reduction,
             **kwargs_decoder,
         )

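A usage sketch under assumptions: the diff only confirms the new `loss_reduction` keyword and its default of `"mean"`; the model object, the batch contents, and the use of `-100` as label padding are assumed here for illustration. Passing `"sum"` lets the training loop normalize the loss itself, e.g. by the number of label tokens accumulated across micro-batches.

# Hypothetical training-step sketch: `model` is assumed to be a ParlerTTSForConditionalGeneration
# and `batch` a dict with input_ids, prompt_input_ids, attention masks, and labels of shape
# (batch, seq_len, num_codebooks), as prepared by the usual data collator.
outputs = model(**batch, loss_reduction="sum")  # loss is now an unnormalized sum

# Re-normalize by the number of non-padding label entries (assuming -100 marks padding),
# which keeps the loss scale stable when accumulating gradients over several micro-batches.
num_label_tokens = (batch["labels"] != -100).sum().clamp(min=1)
(outputs.loss / num_label_tokens).backward()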