Skip to content

Commit 96da1e2

Browse files
authored
Merge pull request #70 from TensorSpeech/fix/transducer
Fix typo and format for transducer
2 parents 189c7d7 + d750838 commit 96da1e2

File tree

5 files changed

+15
-9
lines changed

5 files changed

+15
-9
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ Session.vim
99
.idea
1010
.vscode
1111
__pycache__
12+
.pytest*
1213
venv

setup.cfg

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
[flake8]
22
ignore = E402,E701,E702,E704,E251
3-
max-line-length = 150
3+
max-line-length = 127
44

55
[pep8]
66
ignore = E402,E701,E702,E704,E251
7-
max-line-length = 150
7+
max-line-length = 127
88
indent-size = 4

tensorflow_asr/models/streaming_transducer.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def recognize(self, signals):
257257
"""
258258
def execute(signal: tf.Tensor):
259259
features = self.speech_featurizer.tf_extract(signal)
260-
encoded, _ = self.encoder_inference(features, self.encoder.get_initial_states())
260+
encoded, _ = self.encoder_inference(features, self.encoder.get_initial_state())
261261
hypothesis = self.perform_greedy(
262262
encoded,
263263
predicted=tf.constant(self.text_featurizer.blank, dtype=tf.int32),
@@ -310,10 +310,13 @@ def recognize_beam(self, signals, lm=False):
310310
"""
311311
def execute(signal: tf.Tensor):
312312
features = self.speech_featurizer.tf_extract(signal)
313-
encoded, _ = self.encoder_inference(features, self.encoder.get_initial_states())
313+
encoded, _ = self.encoder_inference(features, self.encoder.get_initial_state())
314314
hypothesis = self.perform_beam_search(encoded, lm)
315-
prediction = tf.map_fn(lambda x: tf.strings.to_number(x, tf.int32),
316-
tf.strings.split(hypothesis.prediction), fn_output_signature=tf.TensorSpec([], dtype=tf.int32))
315+
prediction = tf.map_fn(
316+
lambda x: tf.strings.to_number(x, tf.int32),
317+
tf.strings.split(hypothesis.prediction),
318+
fn_output_signature=tf.TensorSpec([], dtype=tf.int32)
319+
)
317320
transcripts = self.text_featurizer.iextract(tf.expand_dims(prediction, axis=0))
318321
return tf.squeeze(transcripts) # reshape from [1] to []
319322

tensorflow_asr/models/transducer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,8 +451,11 @@ def execute(signal: tf.Tensor):
451451
features = self.speech_featurizer.tf_extract(signal)
452452
encoded = self.encoder_inference(features)
453453
hypothesis = self.perform_beam_search(encoded, lm)
454-
prediction = tf.map_fn(lambda x: tf.strings.to_number(x, tf.int32),
455-
tf.strings.split(hypothesis.prediction), fn_output_signature=tf.TensorSpec([], dtype=tf.int32))
454+
prediction = tf.map_fn(
455+
lambda x: tf.strings.to_number(x, tf.int32),
456+
tf.strings.split(hypothesis.prediction),
457+
fn_output_signature=tf.TensorSpec([], dtype=tf.int32)
458+
)
456459
transcripts = self.text_featurizer.iextract(tf.expand_dims(prediction, axis=0))
457460
return tf.squeeze(transcripts) # reshape from [1] to []
458461

tests/plot_learning_rate.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import math
1615
import tensorflow as tf
1716
import matplotlib.pyplot as plt
1817
from tensorflow_asr.optimizers.schedules import SANSchedule, TransformerSchedule

0 commit comments

Comments
 (0)