 import os

 import keras
+import numpy as np
 import pytest

 from keras_hub.src.models.moonshine.moonshine_audio_converter import (
@@ -26,9 +27,14 @@ def setUp(self):
             "tokenizer": self.tokenizer,
             "decoder_sequence_length": 8,
         }
+        # NOTE: Since keras.ops.convert_to_tensor() does not support
+        # dtype="string" for the JAX and PyTorch backends, the only way to pass
+        # inputs that aren't a mix of tensors and non-tensors is to use a
+        # library-specific function. Using np.random.normal here as a substitute
+        # to a librosa.load() call.
         self.input_data = (
             {
-                "audio": keras.random.normal((1, 16000, 1)),
+                "audio": np.random.normal(size=(1, 16000, 1)),
                 "text": ["the quick brown fox"],
             },
         )
@@ -76,6 +82,24 @@ def test_generate_postprocess(self):
         self.assertIsInstance(output, list)
         self.assertIsInstance(output[0], str)

+    def test_generate_postprocess_batched(self):
+        preprocessor = MoonshineAudioToTextPreprocessor(**self.init_kwargs)
+        batch_size = 3
+        sequence_length = 5
+        input_data = {
+            "decoder_token_ids": keras.ops.ones(
+                (batch_size, sequence_length), dtype="int32"
+            ),
+            "decoder_padding_mask": keras.ops.ones(
+                (batch_size, sequence_length)
+            ),
+        }
+        output = preprocessor.generate_postprocess(input_data)
+        self.assertIsInstance(output, list)
+        self.assertEqual(len(output), batch_size)
+        for item in output:
+            self.assertIsInstance(item, str)
+
     @pytest.mark.extra_large
     def test_all_presets(self):
         for preset in MoonshineAudioToTextPreprocessor.presets:
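
Context for the NOTE in the setUp hunk above: only the numeric part of the preprocessor input can be tensor-like on every backend, because JAX and PyTorch have no "string" dtype for keras.ops.convert_to_tensor(). A minimal standalone sketch of how that mixed test input is built, assuming only NumPy is installed (np.random.normal stands in for a real librosa.load() call, as the comment states):

    import numpy as np

    # Fake one-second mono clip at 16 kHz, shaped (batch, samples, channels),
    # substituting for audio that would normally come from librosa.load().
    audio = np.random.normal(size=(1, 16000, 1))

    # The text stays a plain Python list of strings; it is deliberately not
    # converted to a tensor, since the JAX and PyTorch backends cannot
    # represent string tensors.
    input_data = {
        "audio": audio,
        "text": ["the quick brown fox"],
    }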