@@ -1881,6 +1881,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        loss_reduction: str = "mean",
     ) -> Union[Tuple, ParlerTTSCausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
@@ -1925,7 +1926,7 @@ def forward(
             # since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels
             logits = lm_logits[:, :, -labels.shape[1] :]

-            loss_fct = CrossEntropyLoss()
+            loss_fct = CrossEntropyLoss(reduction=loss_reduction)
             loss = torch.zeros([], device=self.device)

             per_codebook_losses = []
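
The only behavioral change in the two hunks above is that `CrossEntropyLoss` is now built with a caller-chosen `reduction` instead of the hard-coded default of `"mean"`. Below is a minimal, self-contained sketch of what that switch means for a per-codebook loss; the shapes, the loop, and the final division by the number of codebooks are illustrative assumptions, not the Parler-TTS implementation itself.

import torch
from torch.nn import CrossEntropyLoss

# Illustrative shapes only: 2 codebooks, 6 flattened tokens, vocab of 10.
num_codebooks, num_tokens, vocab_size = 2, 6, 10
logits = torch.randn(num_codebooks, num_tokens, vocab_size)
labels = torch.randint(0, vocab_size, (num_codebooks, num_tokens))

def per_codebook_loss(loss_reduction: str = "mean") -> torch.Tensor:
    # Mirrors the diff: the reduction string is forwarded to CrossEntropyLoss.
    loss_fct = CrossEntropyLoss(reduction=loss_reduction)
    loss = torch.zeros([])
    per_codebook_losses = []
    for codebook in range(num_codebooks):
        codebook_loss = loss_fct(logits[codebook], labels[codebook])
        per_codebook_losses.append(codebook_loss)
        loss = loss + codebook_loss
    return loss / num_codebooks

print(per_codebook_loss("mean"))  # token-averaged loss per codebook, then averaged over codebooks
print(per_codebook_loss("sum"))   # token losses are summed, so the value scales with sequence length

Note that `reduction="none"` is also accepted by `CrossEntropyLoss`, but it would make each codebook loss a per-token tensor rather than a scalar, so whether it is usable here depends on how the surrounding code consumes `loss` and `per_codebook_losses`.
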
@@ -2713,6 +2714,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        loss_reduction: str = "mean",
         **kwargs,
     ) -> Union[Tuple, ParlerTTSSeq2SeqLMOutput]:
         r"""
@@ -2857,6 +2859,7 @@ def forward(
             return_dict=return_dict,
             labels=labels,
             cache_position=cache_position,
+            loss_reduction=loss_reduction,
             **kwargs_decoder,
         )
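
The last two hunks only plumb the new keyword through the seq2seq wrapper: its `forward` gains a `loss_reduction` argument with the same `"mean"` default and passes it on to the decoder unchanged. A toy sketch of that pattern follows; the class and attribute names here are invented for illustration, not the Parler-TTS ones.

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

class ToyDecoder(nn.Module):
    def __init__(self, vocab_size: int = 10):
        super().__init__()
        self.lm_head = nn.Linear(8, vocab_size)

    def forward(self, hidden_states, labels, loss_reduction: str = "mean"):
        logits = self.lm_head(hidden_states)
        loss_fct = CrossEntropyLoss(reduction=loss_reduction)
        return loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

class ToySeq2Seq(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = ToyDecoder()

    def forward(self, hidden_states, labels, loss_reduction: str = "mean"):
        # Same pattern as the diff: forward the new kwarg straight to the decoder.
        return self.decoder(hidden_states, labels, loss_reduction=loss_reduction)

model = ToySeq2Seq()
hidden = torch.randn(2, 5, 8)
labels = torch.randint(0, 10, (2, 5))
print(model(hidden, labels, loss_reduction="mean"))
print(model(hidden, labels, loss_reduction="sum"))

Keeping `"mean"` as the default leaves existing callers' behavior unchanged, while callers that want an unscaled sum (for example, training loops that normalize the loss themselves) can opt in per call.
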