
Commit 7f1f011

Fix batch preprocessing bug in Moonshine generation (#2266)
* bug fix: To ragged in batched generation if available
* fix: Remove conditional
* feat: Add batched post-processing test coverage
1 parent 2d79a25 commit 7f1f011

File tree

2 files changed: +28 -1 lines changed

keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py

Lines changed: 3 additions & 0 deletions

@@ -266,4 +266,7 @@ def generate_postprocess(self, x):
                 and 0 <= token < vocab_size
             ]
             processed_sequences.append(filtered_tokens)
+        processed_sequences = tf.ragged.constant(
+            processed_sequences, dtype=tf.int32
+        )
         return self.tokenizer.detokenize(processed_sequences)
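
In a batch, each filtered token list can end up with a different length after invalid tokens are dropped, so the lists cannot be stacked into a dense tensor before detokenization; converting them to a ragged tensor keeps one variable-length row per sequence. A minimal standalone sketch of that conversion, using made-up token IDs and plain TensorFlow (independent of the Moonshine tokenizer):

import tensorflow as tf

# Made-up token IDs: each row keeps a different number of tokens after
# filtering, so stacking these lists into a dense tensor would fail.
processed_sequences = [[5, 12, 7], [9, 3], [14, 2, 8, 6]]

ragged = tf.ragged.constant(processed_sequences, dtype=tf.int32)
print(ragged.shape)  # (3, None): one variable-length row per sequence

The ragged tensor is what the commit now hands to self.tokenizer.detokenize(), which, per the new batched test below, yields one string per example.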

keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor_test.py

Lines changed: 25 additions & 1 deletion

@@ -1,6 +1,7 @@
 import os
 
 import keras
+import numpy as np
 import pytest
 
 from keras_hub.src.models.moonshine.moonshine_audio_converter import (

@@ -26,9 +27,14 @@ def setUp(self):
             "tokenizer": self.tokenizer,
             "decoder_sequence_length": 8,
         }
+        # NOTE: Since keras.ops.convert_to_tensor() does not support
+        # dtype="string" for the JAX and PyTorch backends, the only way to pass
+        # inputs that aren't a mix of tensors and non-tensors is to use a
+        # library-specific function. Using np.random.normal here as a substitute
+        # to a librosa.load() call.
         self.input_data = (
             {
-                "audio": keras.random.normal((1, 16000, 1)),
+                "audio": np.random.normal(size=(1, 16000, 1)),
                 "text": ["the quick brown fox"],
             },
         )

@@ -76,6 +82,24 @@ def test_generate_postprocess(self):
         self.assertIsInstance(output, list)
         self.assertIsInstance(output[0], str)
 
+    def test_generate_postprocess_batched(self):
+        preprocessor = MoonshineAudioToTextPreprocessor(**self.init_kwargs)
+        batch_size = 3
+        sequence_length = 5
+        input_data = {
+            "decoder_token_ids": keras.ops.ones(
+                (batch_size, sequence_length), dtype="int32"
+            ),
+            "decoder_padding_mask": keras.ops.ones(
+                (batch_size, sequence_length)
+            ),
+        }
+        output = preprocessor.generate_postprocess(input_data)
+        self.assertIsInstance(output, list)
+        self.assertEqual(len(output), batch_size)
+        for item in output:
+            self.assertIsInstance(item, str)
+
     @pytest.mark.extra_large
     def test_all_presets(self):
         for preset in MoonshineAudioToTextPreprocessor.presets:
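
As the NOTE in setUp() explains, np.random.normal() stands in for real audio so the fixture works on all backends. For context, a hedged sketch of what that stand-in replaces: loading one second of 16 kHz audio with librosa and shaping it to the (batch, samples, channels) layout used above. The file path is hypothetical and librosa is assumed to be installed; it is not a dependency of this test.

import numpy as np
import librosa  # assumed available; not required by the test itself

# Hypothetical path; resample to 16 kHz mono on load.
waveform, sample_rate = librosa.load("speech.wav", sr=16000, mono=True)

# Take one second of audio and shape it to (batch, samples, channels),
# matching the (1, 16000, 1) random fixture in setUp().
audio = waveform[:16000].reshape(1, -1, 1).astype("float32")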
