Skip to content

Commit 817c6f0

Browse files
committed
⚡ Fixed tflite conformer
1 parent fbf9757 commit 817c6f0

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tensorflow_asr/models/conformer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .layers.subsampling import VggSubsampling, Conv2dSubsampling
2020
from .layers.positional_encoding import PositionalEncoding, PositionalEncodingConcat
2121
from .layers.multihead_attention import MultiHeadAttention, RelPositionMultiHeadAttention
22+
from ..utils.utils import shape_list
2223

2324
L2 = tf.keras.regularizers.l2(1e-6)
2425

@@ -180,14 +181,15 @@ def __init__(self,
180181

181182
def call(self, inputs, training=False, **kwargs):
182183
outputs = self.ln(inputs, training=training)
183-
outputs = tf.expand_dims(outputs, axis=2)
184+
B, T, E = shape_list(outputs)
185+
outputs = tf.reshape(outputs, [B, T, 1, E])
184186
outputs = self.pw_conv_1(outputs, training=training)
185187
outputs = self.glu(outputs)
186188
outputs = self.dw_conv(outputs, training=training)
187189
outputs = self.bn(outputs, training=training)
188190
outputs = self.swish(outputs)
189191
outputs = self.pw_conv_2(outputs, training=training)
190-
outputs = tf.squeeze(outputs, axis=2)
192+
outputs = tf.reshape(outputs, [B, T, E])
191193
outputs = self.do(outputs, training=training)
192194
outputs = self.res_add([inputs, outputs])
193195
return outputs

0 commit comments

Comments
 (0)