Skip to content

Commit 1551b7c

Browse files
authored
add possibility to have audio_output_lengths (#91)
1 parent 862f841 commit 1551b7c

File tree

1 file changed

+3 −1 lines changed

1 file changed

+3 −1 lines changed

parler_tts/modeling_parler_tts.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3511,6 +3511,7 @@ def generate(
35113511
output_ids,
35123512
audio_scales=audio_scales,
35133513
).audio_values.squeeze(1)
3514+
output_lengths = [audio.shape[0] for audio in output_values]
35143515
else:
35153516
output_values = []
35163517
for sample_id in range(batch_size):
@@ -3522,13 +3523,14 @@ def generate(
35223523
output_values.append(sample.transpose(0, 2))
35233524
else:
35243525
output_values.append(torch.zeros((1, 1, 1)).to(self.device))
3525-
# TODO: we should keep track of output length as well. Not really straightforward tbh
3526+
output_lengths = [audio.shape[0] for audio in output_values]
35263527
output_values = (
35273528
torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
35283529
.squeeze(-1)
35293530
.squeeze(-1)
35303531
)
35313532
if generation_config.return_dict_in_generate:
3533+
outputs["audios_length"] = output_lengths
35323534
outputs.sequences = output_values
35333535
return outputs
35343536
else:

0 commit comments

Comments
 (0)