Skip to content

Commit 2b19a06

Browse files
authored
Fix gemma3n feature extractor's incorrect squeeze (#39919)
* fix gemma3n squeeze Signed-off-by: Isotr0py <[email protected]> * add regression test Signed-off-by: Isotr0py <[email protected]> --------- Signed-off-by: Isotr0py <[email protected]> Signed-off-by: Isotr0py <[email protected]>
1 parent 555cbf5 commit 2b19a06

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/transformers/models/gemma3n/feature_extraction_gemma3n.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def _extract_spectrogram(self, waveform: np.ndarray, attention_mask: np.ndarray)
261261
if self.per_bin_stddev is not None:
262262
log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting
263263

264-
mel_spectrogram = log_mel_spec.squeeze()
264+
mel_spectrogram = log_mel_spec.squeeze(0)
265265
mask = attention_mask[:: self.hop_length].astype(bool)
266266
# TODO: The filtered mask is always exactly 3 elements longer than the mel_spectrogram. Why???
267267
return mel_spectrogram, mask[: mel_spectrogram.shape[0]]

tests/models/gemma3n/test_feature_extraction_gemma3n.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,25 @@ def test_call(self, audio_inputs, test_truncation=False):
228228
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
229229
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
230230

231+
def test_audio_features_attn_mask_consistent(self):
232+
# regression test for https://github.com/huggingface/transformers/issues/39911
233+
# Test input_features and input_features_mask have consistent shape
234+
np.random.seed(42)
235+
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
236+
for i in [512, 640, 1024]:
237+
audio = np.random.randn(i)
238+
mm_data = {
239+
"raw_speech": [audio],
240+
"sampling_rate": 16000,
241+
}
242+
inputs = feature_extractor(**mm_data, return_tensors="np")
243+
out = inputs["input_features"]
244+
mask = inputs["input_features_mask"]
245+
246+
assert out.ndim == 3
247+
assert mask.ndim == 2
248+
assert out.shape[:2] == mask.shape[:2]
249+
231250
def test_dither(self):
232251
np.random.seed(42) # seed the dithering randn()
233252

0 commit comments

Comments
 (0)