@@ -42,6 +42,11 @@ def __init__(
42
42
self .audio_encoder = model .audio_encoder
43
43
self .generation_config = model .generation_config
44
44
self .device = device if device is not None else model .device
45
+ self .use_audio_scales = model .use_audio_scales
46
+ self .use_4dim_audio_codes = model .use_4dim_audio_codes
47
+ self .audio_kwargs = {}
48
+ if self .use_audio_scales :
49
+ self .audio_kwargs ["audio_scales" ] = [None ]
45
50
46
51
# variables used in the streaming process
47
52
self .play_steps = play_steps
@@ -72,8 +77,10 @@ def apply_delay_pattern_mask(self, input_ids):
72
77
# revert the pattern delay mask by filtering the pad token id
73
78
mask = (delay_pattern_mask != self .generation_config .bos_token_id ) & (delay_pattern_mask != self .generation_config .pad_token_id )
74
79
input_ids = input_ids [mask ].reshape (1 , self .decoder .num_codebooks , - 1 )
75
- # append the frame dimension back to the audio codes
76
- input_ids = input_ids [None , ...]
80
+
81
+ if self .use_4dim_audio_codes :
82
+ # append the frame dimension back to the audio codes
83
+ input_ids = input_ids [None , ...]
77
84
78
85
# send the input_ids to the correct device
79
86
input_ids = input_ids .to (self .audio_encoder .device )
@@ -84,17 +91,19 @@ def apply_delay_pattern_mask(self, input_ids):
84
91
or self .generation_config .eos_token_id in input_ids
85
92
)
86
93
if not decode_sequentially :
87
- output_values = self .audio_encoder .decode (
88
- input_ids ,
89
- audio_scales = [None ],
90
- )
94
+ sample = self .audio_encoder .decode (
95
+ audio_codes = input_ids ,
96
+ ** self .audio_kwargs ,
97
+ ).audio_values
98
+ output_values = sample if sample .ndim == 3 else sample .unsqueeze (0 )
91
99
else :
92
- sample = input_ids [:, 0 ]
93
- sample_mask = (sample >= self .audio_encoder .config .codebook_size ).sum (dim = (0 , 1 )) == 0
94
- sample = sample [:, :, sample_mask ]
95
- output_values = self .audio_encoder .decode (sample [None , ...], [None ])
100
+ sample = input_ids [:, 0 ] if self .use_4dim_audio_codes else input_ids [0 ]
101
+ sample_mask = ((sample >= self .audio_encoder .config .codebook_size ).sum (dim = (0 , 1 )) == 0 ) if self .use_4dim_audio_codes else ((sample >= self .audio_encoder .config .codebook_size ).sum (dim = 0 ) == 0 )
102
+ sample = sample [:, :, sample_mask ] if self .use_4dim_audio_codes else sample [:, sample_mask ]
103
+ sample = self .audio_encoder .decode (audio_codes = sample [None , ...], ** self .audio_kwargs ).audio_values
104
+ output_values = sample if sample .ndim == 3 else sample .unsqueeze (0 )
96
105
97
- audio_values = output_values . audio_values [0 , 0 ]
106
+ audio_values = output_values [0 , 0 ]
98
107
return audio_values .cpu ().float ().numpy ()
99
108
100
109
def put (self , value ):
0 commit comments