Skip to content

Commit 86e8c43

Browse files
authored
Merge pull request #115 from TensorSpeech/fix/tflite
Fix TFLite Conversion and Interpretation
2 parents 304330a + 6113267 commit 86e8c43

File tree

10 files changed

+239
-112
lines changed

10 files changed

+239
-112
lines changed

examples/conformer/tflite_conformer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,10 @@
5454
conformer.summary(line_length=150)
5555
conformer.add_featurizers(speech_featurizer, text_featurizer)
5656

57-
concrete_func = conformer.make_tflite_function(greedy=True).get_concrete_function()
57+
concrete_func = conformer.make_tflite_function().get_concrete_function()
5858
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
5959
converter.optimizations = [tf.lite.Optimize.DEFAULT]
60-
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
61-
tf.lite.OpsSet.SELECT_TF_OPS]
60+
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
6261
tflite_model = converter.convert()
6362

6463
if not os.path.exists(os.path.dirname(args.output)):

examples/conformer/tflite_subword_conformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
conformer.summary(line_length=150)
6363
conformer.add_featurizers(speech_featurizer, text_featurizer)
6464

65-
concrete_func = conformer.make_tflite_function(greedy=True).get_concrete_function()
65+
concrete_func = conformer.make_tflite_function().get_concrete_function()
6666
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
6767
converter.experimental_new_converter = True
6868
converter.optimizations = [tf.lite.Optimize.DEFAULT]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
setuptools.setup(
3737
name="TensorFlowASR",
38-
version="0.6.3",
38+
version="0.6.4",
3939
author="Huy Le Nguyen",
4040
author_email="[email protected]",
4141
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",

tensorflow_asr/models/contextnet.py

Lines changed: 87 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
""" Ref: https://github.com/iankur/ContextNet """
1515

16-
from typing import List, Optional
16+
from typing import List
1717
import tensorflow as tf
1818
from .transducer import Transducer
1919
from ..utils.utils import merge_two_last_dims, get_reduced_length
@@ -245,13 +245,94 @@ def call(self, inputs, training=False, **kwargs):
245245
outputs = self.joint_net([enc, pred], training=training, **kwargs)
246246
return outputs
247247

248-
def encoder_inference(self,
249-
features: tf.Tensor,
250-
input_length: Optional[tf.Tensor] = None,
251-
with_batch: bool = False):
248+
def encoder_inference(self, features: tf.Tensor, input_length: tf.Tensor):
252249
with tf.name_scope(f"{self.name}_encoder"):
253-
if with_batch: return self.encoder([features, input_length], training=False)
254250
input_length = tf.expand_dims(tf.shape(features)[0], axis=0)
255251
outputs = tf.expand_dims(features, axis=0)
256252
outputs = self.encoder([outputs, input_length], training=False)
257253
return tf.squeeze(outputs, axis=0)
254+
255+
# -------------------------------- GREEDY -------------------------------------
256+
257+
@tf.function
258+
def recognize(self,
259+
features: tf.Tensor,
260+
input_length: tf.Tensor,
261+
parallel_iterations: int = 10,
262+
swap_memory: bool = True):
263+
"""
264+
RNN Transducer Greedy decoding
265+
Args:
266+
features (tf.Tensor): a batch of padded extracted features
267+
268+
Returns:
269+
tf.Tensor: a batch of decoded transcripts
270+
"""
271+
encoded = self.encoder([features, input_length], training=False)
272+
return self._perform_greedy_batch(encoded, input_length,
273+
parallel_iterations=parallel_iterations, swap_memory=swap_memory)
274+
275+
def recognize_tflite(self, signal, predicted, prediction_states):
276+
"""
277+
Function to convert to tflite using greedy decoding (default streaming mode)
278+
Args:
279+
signal: tf.Tensor with shape [None] indicating a single audio signal
280+
predicted: last predicted character with shape []
281+
prediction_states: lastest prediction states with shape [num_rnns, 1 or 2, 1, P]
282+
283+
Return:
284+
transcript: tf.Tensor of Unicode Code Points with shape [None] and dtype tf.int32
285+
predicted: last predicted character with shape []
286+
encoder_states: lastest encoder states with shape [num_rnns, 1 or 2, 1, P]
287+
prediction_states: lastest prediction states with shape [num_rnns, 1 or 2, 1, P]
288+
"""
289+
features = self.speech_featurizer.tf_extract(signal)
290+
encoded = self.encoder_inference(features, tf.shape(features)[0])
291+
hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states)
292+
transcript = self.text_featurizer.indices2upoints(hypothesis.prediction)
293+
return transcript, hypothesis.index, hypothesis.states
294+
295+
def recognize_tflite_with_timestamp(self, signal, predicted, states):
296+
features = self.speech_featurizer.tf_extract(signal)
297+
encoded = self.encoder_inference(features, tf.shape(features)[0])
298+
hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states)
299+
indices = self.text_featurizer.normalize_indices(hypothesis.prediction)
300+
upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1)) # [None, max_subword_length]
301+
302+
num_samples = tf.cast(tf.shape(signal)[0], dtype=tf.float32)
303+
total_time_reduction_factor = self.time_reduction_factor * self.speech_featurizer.frame_step
304+
305+
stime = tf.range(0, num_samples, delta=total_time_reduction_factor, dtype=tf.float32)
306+
stime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)
307+
308+
etime = tf.range(total_time_reduction_factor, num_samples, delta=total_time_reduction_factor, dtype=tf.float32)
309+
etime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)
310+
311+
non_blank = tf.where(tf.not_equal(upoints, 0))
312+
non_blank_transcript = tf.gather_nd(upoints, non_blank)
313+
non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
314+
non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
315+
316+
return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.index, hypothesis.states
317+
318+
# -------------------------------- BEAM SEARCH -------------------------------------
319+
320+
@tf.function
321+
def recognize_beam(self,
322+
features: tf.Tensor,
323+
input_length: tf.Tensor,
324+
lm: bool = False,
325+
parallel_iterations: int = 10,
326+
swap_memory: bool = True):
327+
"""
328+
RNN Transducer Beam Search
329+
Args:
330+
features (tf.Tensor): a batch of padded extracted features
331+
lm (bool, optional): whether to use language model. Defaults to False.
332+
333+
Returns:
334+
tf.Tensor: a batch of decoded transcripts
335+
"""
336+
encoded = self.encoder([features, input_length], training=False)
337+
return self._perform_beam_search_batch(encoded, input_length, lm,
338+
parallel_iterations=parallel_iterations, swap_memory=swap_memory)

tensorflow_asr/models/streaming_transducer.py

Lines changed: 5 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414
""" http://arxiv.org/abs/1811.06621 """
1515

16-
from typing import Optional
1716
import tensorflow as tf
1817

1918
from .layers.subsampling import TimeReduction
@@ -225,24 +224,18 @@ def __init__(self,
225224
)
226225
self.time_reduction_factor = self.encoder.time_reduction_factor
227226

228-
def encoder_inference(self,
229-
features: tf.Tensor,
230-
states: tf.Tensor,
231-
input_length: Optional[tf.Tensor] = None,
232-
with_batch: bool = False):
227+
def encoder_inference(self, features: tf.Tensor, states: tf.Tensor):
233228
"""Infer function for encoder (or encoders)
234229
235230
Args:
236231
features (tf.Tensor): features with shape [T, F, C]
237232
states (tf.Tensor): previous states of encoders with shape [num_rnns, 1 or 2, 1, P]
238-
with_batch (bool): indicates whether the features included batch dim or not
239233
240234
Returns:
241235
tf.Tensor: output of encoders with shape [T, E]
242236
tf.Tensor: states of encoders with shape [num_rnns, 1 or 2, 1, P]
243237
"""
244238
with tf.name_scope(f"{self.name}_encoder"):
245-
if with_batch: return self.encoder.recognize(features, states)
246239
outputs = tf.expand_dims(features, axis=0)
247240
outputs, new_states = self.encoder.recognize(outputs, states)
248241
return tf.squeeze(outputs, axis=0), new_states
@@ -263,11 +256,7 @@ def recognize(self,
263256
Returns:
264257
tf.Tensor: a batch of decoded transcripts
265258
"""
266-
encoded, _ = self.encoder_inference(
267-
features,
268-
self.encoder.get_initial_state(),
269-
input_length=input_length, with_batch=True
270-
)
259+
encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state())
271260
return self._perform_greedy_batch(encoded, input_length,
272261
parallel_iterations=parallel_iterations, swap_memory=swap_memory)
273262

@@ -290,12 +279,7 @@ def recognize_tflite(self, signal, predicted, encoder_states, prediction_states)
290279
encoded, new_encoder_states = self.encoder_inference(features, encoder_states)
291280
hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states)
292281
transcript = self.text_featurizer.indices2upoints(hypothesis.prediction)
293-
return (
294-
transcript,
295-
hypothesis.prediction[-1],
296-
new_encoder_states,
297-
hypothesis.states
298-
)
282+
return transcript, hypothesis.index, new_encoder_states, hypothesis.states
299283

300284
def recognize_tflite_with_timestamp(self, signal, predicted, encoder_states, prediction_states):
301285
features = self.speech_featurizer.tf_extract(signal)
@@ -318,14 +302,7 @@ def recognize_tflite_with_timestamp(self, signal, predicted, encoder_states, pre
318302
non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
319303
non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
320304

321-
return (
322-
non_blank_transcript,
323-
non_blank_stime,
324-
non_blank_etime,
325-
hypothesis.prediction,
326-
new_encoder_states,
327-
hypothesis.states
328-
)
305+
return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.index, new_encoder_states, hypothesis.states
329306

330307
# -------------------------------- BEAM SEARCH -------------------------------------
331308

@@ -345,11 +322,7 @@ def recognize_beam(self,
345322
Returns:
346323
tf.Tensor: a batch of decoded transcripts
347324
"""
348-
encoded, _ = self.encoder_inference(
349-
features,
350-
self.encoder.get_initial_state(),
351-
input_length=input_length, with_batch=True
352-
)
325+
encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state())
353326
return self._perform_beam_search_batch(encoded, input_length, lm,
354327
parallel_iterations=parallel_iterations, swap_memory=swap_memory)
355328

tensorflow_asr/models/transducer.py

Lines changed: 31 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
""" https://arxiv.org/pdf/1811.06621.pdf """
1515

1616
import collections
17-
from typing import Optional
1817
import tensorflow as tf
1918

2019
from . import Model
@@ -285,22 +284,16 @@ def call(self, inputs, training=False, **kwargs):
285284
outputs = self.joint_net([enc, pred], training=training, **kwargs)
286285
return outputs
287286

288-
def encoder_inference(self,
289-
features: tf.Tensor,
290-
input_length: Optional[tf.Tensor] = None,
291-
with_batch: Optional[bool] = False):
287+
def encoder_inference(self, features: tf.Tensor):
292288
"""Infer function for encoder (or encoders)
293289
294290
Args:
295291
features (tf.Tensor): features with shape [T, F, C]
296-
input_length (tf.Tensor): optional features length with shape []
297-
with_batch (bool): indicates whether the features included batch dim or not
298292
299293
Returns:
300294
tf.Tensor: output of encoders with shape [T, E]
301295
"""
302296
with tf.name_scope(f"{self.name}_encoder"):
303-
if with_batch: return self.encoder(features, training=False)
304297
outputs = tf.expand_dims(features, axis=0)
305298
outputs = self.encoder(outputs, training=False)
306299
return tf.squeeze(outputs, axis=0)
@@ -321,7 +314,7 @@ def decoder_inference(self, encoded: tf.Tensor, predicted: tf.Tensor, states: tf
321314
predicted = tf.reshape(predicted, [1, 1]) # [] => [1, 1]
322315
y, new_states = self.predict_net.recognize(predicted, states) # [1, 1, P], states
323316
ytu = tf.nn.log_softmax(self.joint_net([encoded, y], training=False)) # [1, 1, V]
324-
ytu = tf.squeeze(ytu, axis=None) # [1, 1, V] => [V]
317+
ytu = tf.reshape(ytu, shape=[-1]) # [1, 1, V] => [V]
325318
return ytu, new_states
326319

327320
def get_config(self):
@@ -347,7 +340,7 @@ def recognize(self,
347340
Returns:
348341
tf.Tensor: a batch of decoded transcripts
349342
"""
350-
encoded = self.encoder_inference(features, input_length, with_batch=True)
343+
encoded = self.encoder(features, training=True)
351344
return self._perform_greedy_batch(encoded, input_length,
352345
parallel_iterations=parallel_iterations, swap_memory=swap_memory)
353346

@@ -368,11 +361,7 @@ def recognize_tflite(self, signal, predicted, states):
368361
encoded = self.encoder_inference(features)
369362
hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states)
370363
transcript = self.text_featurizer.indices2upoints(hypothesis.prediction)
371-
return (
372-
transcript,
373-
hypothesis.prediction[-1],
374-
hypothesis.states
375-
)
364+
return transcript, hypothesis.index, hypothesis.states
376365

377366
def recognize_tflite_with_timestamp(self, signal, predicted, states):
378367
features = self.speech_featurizer.tf_extract(signal)
@@ -395,7 +384,7 @@ def recognize_tflite_with_timestamp(self, signal, predicted, states):
395384
non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
396385
non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
397386

398-
return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.prediction, hypothesis.states
387+
return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.index, hypothesis.states
399388

400389
def _perform_greedy_batch(self,
401390
encoded: tf.Tensor,
@@ -450,48 +439,47 @@ def _perform_greedy(self,
450439
total = encoded_length
451440

452441
hypothesis = Hypothesis(
453-
index=tf.constant(self.text_featurizer.blank, dtype=tf.int32),
454-
prediction=tf.ones([total], dtype=tf.int32) * self.text_featurizer.blank,
442+
index=predicted,
443+
prediction=tf.TensorArray(
444+
dtype=tf.int32, size=total, dynamic_size=False,
445+
clear_after_read=False, element_shape=tf.TensorShape([])
446+
),
455447
states=states
456448
)
457449

458-
def condition(time, total, encoded, hypothesis): return tf.less(time, total)
450+
def condition(_time, _total, _encoded, _hypothesis): return tf.less(_time, _total)
459451

460-
def body(time, total, encoded, hypothesis):
461-
ytu, states = self.decoder_inference(
452+
def body(_time, _total, _encoded, _hypothesis):
453+
ytu, _states = self.decoder_inference(
462454
# avoid using [index] in tflite
463-
encoded=tf.gather_nd(encoded, tf.expand_dims(time, axis=-1)),
464-
predicted=hypothesis.index,
465-
states=hypothesis.states
455+
encoded=tf.gather_nd(_encoded, tf.reshape(_time, shape=[1])),
456+
predicted=_hypothesis.index,
457+
states=_hypothesis.states
466458
)
467-
predict = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax []
459+
_predict = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax []
468460

469-
index, predict, states = tf.cond(
470-
tf.equal(predict, self.text_featurizer.blank),
471-
true_fn=lambda: (hypothesis.index, predict, hypothesis.states),
472-
false_fn=lambda: (predict, predict, states) # update if the new prediction is a non-blank
473-
)
461+
# something is wrong with tflite that drop support for tf.cond
462+
# def equal_blank_fn(): return _hypothesis.index, _hypothesis.states
463+
# def non_equal_blank_fn(): return _predict, _states # update if the new prediction is a non-blank
464+
# _index, _states = tf.cond(tf.equal(_predict, blank), equal_blank_fn, non_equal_blank_fn)
474465

475-
hypothesis = Hypothesis(
476-
index=index,
477-
prediction=tf.tensor_scatter_nd_update(
478-
hypothesis.prediction,
479-
indices=tf.reshape(time, [1, 1]),
480-
updates=tf.expand_dims(predict, axis=-1)
481-
),
482-
states=states
483-
)
466+
_equal = tf.equal(_predict, self.text_featurizer.blank)
467+
_index = tf.where(_equal, _hypothesis.index, _predict)
468+
_states = tf.where(_equal, _hypothesis.states, _states)
469+
470+
_prediction = _hypothesis.prediction.write(_time, _predict)
471+
_hypothesis = Hypothesis(index=_index, prediction=_prediction, states=_states)
484472

485-
return time + 1, total, encoded, hypothesis
473+
return _time + 1, _total, _encoded, _hypothesis
486474

487-
time, total, encoded, hypothesis = tf.while_loop(
475+
_, _, _, hypothesis = tf.while_loop(
488476
condition, body,
489477
loop_vars=[time, total, encoded, hypothesis],
490478
parallel_iterations=parallel_iterations,
491479
swap_memory=swap_memory
492480
)
493481

494-
return hypothesis
482+
return Hypothesis(index=hypothesis.index, prediction=hypothesis.prediction.stack(), states=hypothesis.states)
495483

496484
# -------------------------------- BEAM SEARCH -------------------------------------
497485

@@ -511,7 +499,7 @@ def recognize_beam(self,
511499
Returns:
512500
tf.Tensor: a batch of decoded transcripts
513501
"""
514-
encoded = self.encoder_inference(features, input_length, with_batch=True)
502+
encoded = self.encoder(features, training=True)
515503
return self._perform_beam_search_batch(encoded, input_length, lm,
516504
parallel_iterations=parallel_iterations, swap_memory=swap_memory)
517505

0 commit comments

Comments
 (0)