@@ -48,6 +48,7 @@ def get_audio_embeddings(
         import torchaudio
 
         all_embeddings = []
+        max_samples = int(self.max_audio_length_seconds * self.sampling_rate)
 
         for batch in tqdm(
             inputs,
@@ -76,22 +77,54 @@ def get_audio_embeddings(
                     array = resampler(array)
 
                 array = array.squeeze()
+
+                # Handle edge case where squeeze results in 0-dim tensor
+                if array.dim() == 0:
+                    array = array.unsqueeze(0)
+
+                # Warn and handle empty audio
+                if array.shape[-1] == 0:
+                    logger.warning(
+                        f"Empty audio sample at index {idx}, using 1 second of silence."
+                    )
+                    array = torch.zeros(self.sampling_rate)  # 1 second of silence
+
+                # Truncate if too long (processor doesn't allow both padding and truncation)
+                if array.shape[-1] > max_samples:
+                    array = array[:max_samples]
+
+                # Ensure minimum length for encoder (EnCodec needs ~320 samples per frame)
+                # Use 1 second minimum to be safe
+                min_samples = self.sampling_rate
+                if array.shape[-1] < min_samples:
+                    padding = torch.zeros(min_samples - array.shape[-1])
+                    array = torch.cat([array, padding])
+
                 audio_arrays.append(array.numpy())
 
             with torch.no_grad():
-                # Process audio through EnCodec's processor
-                max_samples = int(self.max_audio_length_seconds * self.sampling_rate)
-
-                feature_inputs = self.processor(
+                # Use processor for batch padding (truncation/min-length done manually above)
+                processed = self.processor(
                     raw_audio=audio_arrays,
                     sampling_rate=self.sampling_rate,
+                    padding=True,
                     return_tensors="pt",
-                    padding="max_length",
-                    max_length=max_samples,
-                ).to(self.device)
+                )
+                input_values = processed["input_values"].to(self.device)
+
+                # Add channel dimension if needed (B, T) -> (B, 1, T)
+                if input_values.dim() == 2:
+                    input_values = input_values.unsqueeze(1)
 
                 # Get the latent representations directly from the encoder
-                latent = self.model.encoder(feature_inputs.input_values)
+                latent = self.model.encoder(input_values)
+
+                # Validate latent has time frames
+                if latent.shape[2] == 0:
+                    raise ValueError(
+                        f"Encodec encoder produced 0 time frames. "
+                        f"Input shape: {input_values.shape}, latent shape: {latent.shape}"
+                    )
 
                 # Apply mean pooling over the time dimension to get fixed-size embeddings
                 embeddings = torch.mean(latent, dim=2)  # Average over time dimension
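
For reviewers who want to try the change in isolation, here is a minimal standalone sketch of the same pipeline: pad a batch with the processor, run the encoder, and mean-pool the latents over time. It assumes the Hugging Face EncodecModel/AutoProcessor API and the facebook/encodec_24khz checkpoint; the checkpoint name and the embed helper are illustrative, since the PR's actual model/processor construction is outside this hunk.

import numpy as np
import torch
from transformers import AutoProcessor, EncodecModel

# Assumed checkpoint; the class under change builds self.model / self.processor elsewhere.
model = EncodecModel.from_pretrained("facebook/encodec_24khz")
processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")

def embed(waveforms):
    # waveforms: list of 1-D float32 numpy arrays sampled at 24 kHz
    inputs = processor(
        raw_audio=waveforms,
        sampling_rate=24_000,
        padding=True,  # pad the batch to its longest sample
        return_tensors="pt",
    )
    input_values = inputs["input_values"]
    if input_values.dim() == 2:  # (B, T) -> (B, 1, T), mirroring the patch
        input_values = input_values.unsqueeze(1)
    with torch.no_grad():
        latent = model.encoder(input_values)  # (B, 128, T_frames)
    return latent.mean(dim=2)  # fixed-size (B, 128) embeddings

# One second of silence -> a (1, 128) embedding
print(embed([np.zeros(24_000, dtype=np.float32)]).shape)

Mean pooling here is the same design choice as in the patch: it trades the encoder's temporal resolution for a fixed-size vector per clip, which is what the downstream embedding tasks expect.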