Skip to content

Commit f79b3f6

Browse files
committed
✍️ update ctc model
1 parent 2c90f67 commit f79b3f6

File tree

1 file changed

+8
-6
lines changed
  • tensorflow_asr/models/ctc

1 file changed

+8
-6
lines changed

tensorflow_asr/models/ctc/ctc.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def call(self, inputs, training=False, **kwargs):
7676

7777
@tf.function
7878
def recognize(self, inputs: Dict[str, tf.Tensor]):
79-
logits = self(inputs["inputs"], training=False)
80-
probs = tf.nn.softmax(logits)
79+
logits = self(inputs, training=False)
80+
probs = tf.nn.softmax(logits["logits"])
8181

8282
def map_fn(prob): return tf.numpy_function(self._perform_greedy, inp=[prob], Tout=tf.string)
8383

@@ -102,7 +102,8 @@ def recognize_tflite(self, signal):
102102
input_length = shape_util.shape_list(features)[1]
103103
input_length = math_util.get_reduced_length(input_length, self.time_reduction_factor)
104104
input_length = tf.expand_dims(input_length, axis=0)
105-
logits = self(features, training=False)
105+
logits = self.encoder(features, training=False)
106+
logits = self.decoder(logits, training=False)
106107
probs = tf.nn.softmax(logits)
107108
decoded = tf.keras.backend.ctc_decode(
108109
y_pred=probs, input_length=input_length, greedy=True
@@ -115,8 +116,8 @@ def recognize_tflite(self, signal):
115116

116117
@tf.function
117118
def recognize_beam(self, inputs: Dict[str, tf.Tensor], lm: bool = False):
118-
logits = self(inputs["inputs"], training=False)
119-
probs = tf.nn.softmax(logits)
119+
logits = self(inputs, training=False)
120+
probs = tf.nn.softmax(logits["logits"])
120121

121122
def map_fn(prob): return tf.numpy_function(self._perform_beam_search, inp=[prob, lm], Tout=tf.string)
122123

@@ -148,7 +149,8 @@ def recognize_beam_tflite(self, signal):
148149
input_length = shape_util.shape_list(features)[1]
149150
input_length = math_util.get_reduced_length(input_length, self.time_reduction_factor)
150151
input_length = tf.expand_dims(input_length, axis=0)
151-
logits = self(features, training=False)
152+
logits = self.encoder(features, training=False)
153+
logits = self.decoder(logits, training=False)
152154
probs = tf.nn.softmax(logits)
153155
decoded = tf.keras.backend.ctc_decode(
154156
y_pred=probs, input_length=input_length, greedy=False,

0 commit comments

Comments
 (0)