Commit da1c5b9

🚀 update demo examples and dependencies

1 parent c2e1ddf commit da1c5b9
File tree: 7 files changed (+114 / -131 lines)

examples/demonstration/conformer.py
Lines changed: 11 additions & 29 deletions

@@ -15,46 +15,26 @@
 import os
 import argparse
 from tensorflow_asr.utils import setup_environment, setup_devices
+from tensorflow_asr.utils.utils import get_reduced_length
 
 setup_environment()
 import tensorflow as tf
 
 parser = argparse.ArgumentParser(prog="Conformer non streaming")
 
-parser.add_argument("filename", metavar="FILENAME",
-                    help="audio file to be played back")
+parser.add_argument("filename", metavar="FILENAME", help="audio file to be played back")
 
-parser.add_argument("--config", type=str, default=None,
-                    help="Path to conformer config yaml")
+parser.add_argument("--config", type=str, default=None, help="Path to conformer config yaml")
 
-parser.add_argument("--saved", type=str, default=None,
-                    help="Path to conformer saved h5 weights")
-
-parser.add_argument("--blank", type=int, default=0,
-                    help="Path to conformer tflite")
+parser.add_argument("--saved", type=str, default=None, help="Path to conformer saved h5 weights")
 
 parser.add_argument("--beam_width", type=int, default=0, help="Beam width")
 
-parser.add_argument("--num_rnns", type=int, default=1,
-                    help="Number of RNN layers in prediction network")
-
-parser.add_argument("--nstates", type=int, default=2,
-                    help="Number of RNN states in prediction network (1 for GRU and 2 for LSTM)")
-
-parser.add_argument("--statesize", type=int, default=320,
-                    help="Size of RNN state in prediction network")
-
-parser.add_argument("--device", type=int, default=0,
-                    help="Device's id to run test on")
-
-parser.add_argument("--cpu", default=False, action="store_true",
-                    help="Whether to only use cpu")
+parser.add_argument("--device", type=int, default=0, help="Device's id to run test on")
 
-parser.add_argument("--subwords", type=str, default=None,
-                    help="Path to file that stores generated subwords")
+parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu")
 
-parser.add_argument("--output_name", type=str, default="test",
-                    help="Result filename name prefix")
+parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords")
 
 args = parser.parse_args()
 
@@ -83,10 +63,12 @@
 conformer.add_featurizers(speech_featurizer, text_featurizer)
 
 signal = read_raw_audio(args.filename)
+features = speech_featurizer.tf_extract(signal)
+input_length = get_reduced_length(tf.shape(features)[0], conformer.time_reduction_factor)
 
 if (args.beam_width):
-    transcript = conformer.recognize_beam(signal[None, ...])
+    transcript = conformer.recognize_beam(features[None, ...], input_length[None, ...])
 else:
-    transcript = conformer.recognize(features[None, ...], input_length[None, ...])
+    transcript = conformer.recognize(features[None, ...], input_length[None, ...])
 
 tf.print("Transcript:", transcript[0])
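The demo now feeds pre-extracted features plus a reduced input length into recognize/recognize_beam instead of the raw signal. A minimal sketch of what get_reduced_length is assumed to compute, i.e. the number of frames left after the encoder's time reduction (the actual implementation lives in tensorflow_asr.utils.utils and may differ in detail):

    import tensorflow as tf

    def get_reduced_length(length, reduction_factor):
        # frames remaining after time reduction: ceil(length / reduction_factor), as int32
        return tf.cast(
            tf.math.ceil(tf.cast(length, tf.float32) / tf.cast(reduction_factor, tf.float32)),
            dtype=tf.int32,
        )

    # e.g. 97 feature frames with time_reduction_factor == 4 -> a scalar int32 tensor equal to 25
    print(get_reduced_length(tf.constant(97), 4))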

examples/demonstration/streaming_tflite_conformer.py
Lines changed: 54 additions & 49 deletions

@@ -93,7 +93,7 @@ def recognizer(Q):
 
     def recognize(signal, lastid, states):
         if signal.shape[0] < args.blocksize:
-            signal = np.pad(signal, [[0, args.blocksize - signal.shape[0]]])
+            signal = tf.pad(signal, [[0, args.blocksize - signal.shape[0]]])
         tflitemodel.set_tensor(input_details[0]["index"], signal)
         tflitemodel.set_tensor(input_details[1]["index"], lastid)
         tflitemodel.set_tensor(input_details[2]["index"], states)
@@ -104,8 +104,8 @@ def recognize(signal, lastid, states):
         text = "".join([chr(u) for u in upoints])
         return text, lastid, states
 
-    lastid = args.blank * np.ones(shape=[], dtype=np.int32)
-    states = np.zeros(shape=[args.num_rnns, args.nstates, 1, args.statesize], dtype=np.float32)
+    lastid = args.blank * tf.ones(shape=[], dtype=tf.int32)
+    states = tf.zeros(shape=[args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32)
     transcript = ""
 
     while True:
@@ -122,51 +122,56 @@ def recognize(signal, lastid, states):
 tflite_process.start()
 
 
-def callback(outdata, frames, time, status):
-    assert frames == args.blocksize
-    if status.output_underflow:
-        print('Output underflow: increase blocksize?', file=sys.stderr)
-        raise sd.CallbackAbort
-    assert not status
+def send(q, Q, E):
+    def callback(outdata, frames, time, status):
+        assert frames == args.blocksize
+        if status.output_underflow:
+            print('Output underflow: increase blocksize?', file=sys.stderr)
+            raise sd.CallbackAbort
+        assert not status
+        try:
+            data = q.get_nowait()
+            Q.put(np.frombuffer(data, dtype=np.float32))
+        except queue.Empty as e:
+            print('Buffer is empty: increase buffersize?', file=sys.stderr)
+            raise sd.CallbackAbort from e
+        if len(data) < len(outdata):
+            outdata[:len(data)] = data
+            outdata[len(data):] = b'\x00' * (len(outdata) - len(data))
+            raise sd.CallbackStop
+        else:
+            outdata[:] = data
+
     try:
-        data = q.get_nowait()
-        Q.put(np.frombuffer(data, dtype=np.float32))
-    except queue.Empty as e:
-        print('Buffer is empty: increase buffersize?', file=sys.stderr)
-        raise sd.CallbackAbort from e
-    if len(data) < len(outdata):
-        outdata[:len(data)] = data
-        outdata[len(data):] = b'\x00' * (len(outdata) - len(data))
-        raise sd.CallbackStop
-    else:
-        outdata[:] = data
-
-
-try:
-    with sf.SoundFile(args.filename) as f:
-        for _ in range(args.buffersize):
-            data = f.buffer_read(args.blocksize, dtype='float32')
-            if not data:
-                break
-            q.put_nowait(data) # Pre-fill queue
-        stream = sd.RawOutputStream(
-            samplerate=f.samplerate, blocksize=args.blocksize,
-            device=args.device, channels=f.channels, dtype='float32',
-            callback=callback, finished_callback=E.set)
-        with stream:
-            timeout = args.blocksize * args.buffersize / f.samplerate
-            while data:
+        with sf.SoundFile(args.filename) as f:
+            for _ in range(args.buffersize):
                 data = f.buffer_read(args.blocksize, dtype='float32')
-                q.put(data, timeout=timeout)
-            E.wait()
-
-except KeyboardInterrupt:
-    parser.exit('\nInterrupted by user')
-except queue.Full:
-    # A timeout occurred, i.e. there was an error in the callback
-    parser.exit(1)
-except Exception as e:
-    parser.exit(type(e).__name__ + ': ' + str(e))
-
-tflite_process.join()
-tflite_process.close()
+                if not data:
+                    break
+                q.put_nowait(data) # Pre-fill queue
+            stream = sd.RawOutputStream(
+                samplerate=f.samplerate, blocksize=args.blocksize,
+                device=args.device, channels=f.channels, dtype='float32',
+                callback=callback, finished_callback=E.set)
+            with stream:
+                timeout = args.blocksize * args.buffersize / f.samplerate
+                while data:
+                    data = f.buffer_read(args.blocksize, dtype='float32')
+                    q.put(data, timeout=timeout)
+                E.wait()
+
+    except KeyboardInterrupt:
+        parser.exit('\nInterrupted by user')
+    except queue.Full:
+        # A timeout occurred, i.e. there was an error in the callback
+        parser.exit(1)
+    except Exception as e:
+        parser.exit(type(e).__name__ + ': ' + str(e))
+
+
+send_process = Process(target=send, args=[q, Q, E])
+send_process.start()
+send_process.join()
+send_process.close()
+
+tflite_process.terminate()
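The refactor wraps the playback callback and the soundfile reader in a send() function and runs it as its own Process, so audio feeding and TFLite recognition live in separate processes that communicate through queues. A stripped-down sketch of that producer/consumer layout, with synthetic blocks instead of sounddevice/soundfile and a sentinel instead of the Event used in the real script (names are illustrative only):

    import multiprocessing as mp

    import numpy as np


    def send(q):
        # producer: push fixed-size float32 "audio" blocks, then a sentinel
        for _ in range(10):
            q.put(np.random.rand(4096).astype(np.float32))
        q.put(None)


    def recognizer(q):
        # consumer: drain blocks until the sentinel arrives
        consumed = 0
        while True:
            block = q.get()
            if block is None:
                break
            consumed += block.shape[0]
        print("consumed", consumed, "samples")


    if __name__ == "__main__":
        q = mp.Queue()
        rec_process = mp.Process(target=recognizer, args=[q])
        send_process = mp.Process(target=send, args=[q])
        rec_process.start()
        send_process.start()
        send_process.join()
        rec_process.join()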

setup.py
Lines changed: 5 additions & 7 deletions

@@ -19,18 +19,16 @@
 
 requirements = [
     "tensorflow-datasets>=3.2.1,<4.0.0",
-    "tensorflow-metadata>=0.26.0",
     "tensorflow-addons>=0.10.0",
     "setuptools>=47.1.1",
-    "librosa>=0.7.2",
+    "librosa>=0.8.0",
     "soundfile>=0.10.3",
     "PyYAML>=5.3.1",
     "matplotlib>=3.2.1",
-    "sox>=1.3.7",
-    "numba==0.49.1",
-    "tqdm>=4.51.0",
-    "colorama>=0.4.3",
-    "nlpaug>=1.0.1",
+    "sox>=1.4.1",
+    "tqdm>=4.54.1",
+    "colorama>=0.4.4",
+    "nlpaug>=1.1.1",
 ]
 
 setuptools.setup(
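Since the commit raises several minimum versions and drops tensorflow-metadata and numba as direct dependencies, a quick way to list the installed versions of the bumped packages next to the new lower bounds is sketched below (a hypothetical helper, not part of the repo):

    from importlib.metadata import PackageNotFoundError, version

    # minimums taken from the updated requirements list above
    minimums = {
        "librosa": "0.8.0",
        "sox": "1.4.1",
        "tqdm": "4.54.1",
        "colorama": "0.4.4",
        "nlpaug": "1.1.1",
    }

    for name, minimum in minimums.items():
        try:
            print(f"{name}: installed {version(name)} (requires >= {minimum})")
        except PackageNotFoundError:
            print(f"{name}: not installed (requires >= {minimum})")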

tensorflow_asr/featurizers/speech_featurizers.py
Lines changed: 2 additions & 4 deletions

@@ -428,14 +428,12 @@ def tf_extract(self, signal: tf.Tensor) -> tf.Tensor:
         elif self.feature_type == "log_gammatone_spectrogram":
             features = self.compute_log_gammatone_spectrogram(signal)
         else:
-            raise ValueError("feature_type must be either 'mfcc',"
-                             "'log_mel_spectrogram' or 'spectrogram'")
+            raise ValueError("feature_type must be either 'mfcc', 'log_mel_spectrogram' or 'spectrogram'")
 
         features = tf.expand_dims(features, axis=-1)
 
         if self.normalize_feature:
-            features = tf_normalize_audio_features(
-                features, per_feature=self.normalize_per_feature)
+            features = tf_normalize_audio_features(features, per_feature=self.normalize_per_feature)
 
         return features
 
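For readers unfamiliar with the per_feature flag, tf_normalize_audio_features is assumed to perform a zero-mean, unit-variance scaling of the extracted features, either per feature bin or over the whole utterance. A hedged sketch of that idea (the repo's actual implementation may differ):

    import tensorflow as tf

    def normalize_audio_features(features: tf.Tensor, per_feature: bool = False) -> tf.Tensor:
        # features: [T, num_feature_bins]; standardize per bin (axis=0) or globally
        axis = 0 if per_feature else None
        mean = tf.reduce_mean(features, axis=axis, keepdims=True)
        std = tf.math.reduce_std(features, axis=axis, keepdims=True)
        return (features - mean) / (std + 1e-9)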

tensorflow_asr/models/ctc.py
Lines changed: 4 additions & 4 deletions

@@ -46,11 +46,11 @@ def recognize(self, features: tf.Tensor, input_length: Optional[tf.Tensor]):
         logits = self(features, training=False)
         probs = tf.nn.softmax(logits)
 
-        def map_fn(prob): return tf.numpy_function(self.__perform_greedy, inp=[prob], Tout=tf.string)
+        def map_fn(prob): return tf.numpy_function(self._perform_greedy, inp=[prob], Tout=tf.string)
 
         return tf.map_fn(map_fn, probs, fn_output_signature=tf.TensorSpec([], dtype=tf.string))
 
-    def __perform_greedy(self, probs: np.ndarray):
+    def _perform_greedy(self, probs: np.ndarray):
         from ctc_decoders import ctc_greedy_decoder
         decoded = ctc_greedy_decoder(probs, vocabulary=self.text_featurizer.vocab_array)
         return tf.convert_to_tensor(decoded, dtype=tf.string)
@@ -85,11 +85,11 @@ def recognize_beam(self, features: tf.Tensor, input_length: Optional[tf.Tensor],
         logits = self(features, training=False)
         probs = tf.nn.softmax(logits)
 
-        def map_fn(prob): return tf.numpy_function(self.__perform_beam_search, inp=[prob, lm], Tout=tf.string)
+        def map_fn(prob): return tf.numpy_function(self._perform_beam_search, inp=[prob, lm], Tout=tf.string)
 
         return tf.map_fn(map_fn, probs, dtype=tf.string)
 
-    def __perform_beam_search(self, probs: np.ndarray, lm: bool = False):
+    def _perform_beam_search(self, probs: np.ndarray, lm: bool = False):
         from ctc_decoders import ctc_beam_search_decoder
         decoded = ctc_beam_search_decoder(
             probs_seq=probs,
tensorflow_asr/models/streaming_transducer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,8 @@ def recognize(self,
266266
self.encoder.get_initial_state(),
267267
input_length=input_length, with_batch=True
268268
)
269-
return self.__perform_greedy_batch(encoded, input_length,
270-
parallel_iterations=parallel_iterations, swap_memory=swap_memory)
269+
return self._perform_greedy_batch(encoded, input_length,
270+
parallel_iterations=parallel_iterations, swap_memory=swap_memory)
271271

272272
def recognize_tflite(self, signal, predicted, encoder_states, prediction_states):
273273
"""
@@ -286,7 +286,7 @@ def recognize_tflite(self, signal, predicted, encoder_states, prediction_states)
286286
"""
287287
features = self.speech_featurizer.tf_extract(signal)
288288
encoded, new_encoder_states = self.encoder_inference(features, encoder_states)
289-
hypothesis = self.__perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states)
289+
hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states)
290290
transcript = self.text_featurizer.indices2upoints(hypothesis.prediction)
291291
return (
292292
transcript,
@@ -318,8 +318,8 @@ def recognize_beam(self,
318318
self.encoder.get_initial_state(),
319319
input_length=input_length, with_batch=True
320320
)
321-
return self.__perform_beam_search_batch(encoded, input_length, lm,
322-
parallel_iterations=parallel_iterations, swap_memory=swap_memory)
321+
return self._perform_beam_search_batch(encoded, input_length, lm,
322+
parallel_iterations=parallel_iterations, swap_memory=swap_memory)
323323

324324
# -------------------------------- TFLITE -------------------------------------
325325
