Skip to content

Commit d108732

Browse files
authored
Fix gradient accumulation (#176)
Co-authored-by: Yoach Lacombe <[email protected]>
1 parent 3d1b82a commit d108732

File tree

2 files changed

+209
-171
lines changed

2 files changed

+209
-171
lines changed

parler_tts/modeling_parler_tts.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1881,6 +1881,7 @@ def forward(
18811881
output_hidden_states: Optional[bool] = None,
18821882
return_dict: Optional[bool] = None,
18831883
cache_position: Optional[torch.LongTensor] = None,
1884+
loss_reduction: str = "mean",
18841885
) -> Union[Tuple, ParlerTTSCausalLMOutputWithCrossAttentions]:
18851886
r"""
18861887
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
@@ -1925,7 +1926,7 @@ def forward(
19251926
# since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels
19261927
logits = lm_logits[:, :, -labels.shape[1] :]
19271928

1928-
loss_fct = CrossEntropyLoss()
1929+
loss_fct = CrossEntropyLoss(reduction=loss_reduction)
19291930
loss = torch.zeros([], device=self.device)
19301931

19311932
per_codebook_losses = []
@@ -2713,6 +2714,7 @@ def forward(
27132714
output_hidden_states: Optional[bool] = None,
27142715
return_dict: Optional[bool] = None,
27152716
cache_position: Optional[torch.LongTensor] = None,
2717+
loss_reduction: str = "mean",
27162718
**kwargs,
27172719
) -> Union[Tuple, ParlerTTSSeq2SeqLMOutput]:
27182720
r"""
@@ -2857,6 +2859,7 @@ def forward(
28572859
return_dict=return_dict,
28582860
labels=labels,
28592861
cache_position=cache_position,
2862+
loss_reduction=loss_reduction,
28602863
**kwargs_decoder,
28612864
)
28622865

0 commit comments

Comments (0)