File tree Expand file tree Collapse file tree 1 file changed +3
-1
lines changed Expand file tree Collapse file tree 1 file changed +3
-1
lines changed Original file line number Diff line number Diff line change @@ -3511,6 +3511,7 @@ def generate(
3511
3511
output_ids ,
3512
3512
audio_scales = audio_scales ,
3513
3513
).audio_values .squeeze (1 )
3514
+ output_lengths = [audio .shape [0 ] for audio in output_values ]
3514
3515
else :
3515
3516
output_values = []
3516
3517
for sample_id in range (batch_size ):
@@ -3522,13 +3523,14 @@ def generate(
3522
3523
output_values .append (sample .transpose (0 , 2 ))
3523
3524
else :
3524
3525
output_values .append (torch .zeros ((1 , 1 , 1 )).to (self .device ))
3525
- # TODO: we should keep track of output length as well. Not really straightforward tbh
3526
+ output_lengths = [ audio . shape [ 0 ] for audio in output_values ]
3526
3527
output_values = (
3527
3528
torch .nn .utils .rnn .pad_sequence (output_values , batch_first = True , padding_value = 0 )
3528
3529
.squeeze (- 1 )
3529
3530
.squeeze (- 1 )
3530
3531
)
3531
3532
if generation_config .return_dict_in_generate :
3533
+ outputs ["audios_length" ] = output_lengths
3532
3534
outputs .sequences = output_values
3533
3535
return outputs
3534
3536
else :
You can’t perform that action at this time.
0 commit comments