|
25 | 25 | class CtcModel(Model): |
26 | 26 | def __init__(self, **kwargs): |
27 | 27 | super(CtcModel, self).__init__(**kwargs) |
| 28 | + self.time_reduction_factor = 1 |
28 | 29 |
|
29 | 30 | def _build(self, input_shape): |
30 | 31 | features = tf.keras.Input(input_shape, dtype=tf.float32) |
@@ -67,7 +68,7 @@ def recognize_tflite(self, signal): |
67 | 68 | features = self.speech_featurizer.tf_extract(signal) |
68 | 69 | features = tf.expand_dims(features, axis=0) |
69 | 70 | input_length = shape_list(features)[1] |
70 | | - input_length = get_reduced_length(input_length, self.base_model.time_reduction_factor) |
| 71 | + input_length = get_reduced_length(input_length, self.time_reduction_factor) |
71 | 72 | input_length = tf.expand_dims(input_length, axis=0) |
72 | 73 | logits = self(features, training=False) |
73 | 74 | probs = tf.nn.softmax(logits) |
@@ -113,7 +114,7 @@ def recognize_beam_tflite(self, signal): |
113 | 114 | features = self.speech_featurizer.tf_extract(signal) |
114 | 115 | features = tf.expand_dims(features, axis=0) |
115 | 116 | input_length = shape_list(features)[1] |
116 | | - input_length = get_reduced_length(input_length, self.base_model.time_reduction_factor) |
| 117 | + input_length = get_reduced_length(input_length, self.time_reduction_factor) |
117 | 118 | input_length = tf.expand_dims(input_length, axis=0) |
118 | 119 | logits = self(features, training=False) |
119 | 120 | probs = tf.nn.softmax(logits) |
|
0 commit comments