19 | 19 | from .layers.subsampling import VggSubsampling, Conv2dSubsampling |
20 | 20 | from .layers.positional_encoding import PositionalEncoding, PositionalEncodingConcat |
21 | 21 | from .layers.multihead_attention import MultiHeadAttention, RelPositionMultiHeadAttention |
| 22 | +from ..utils.utils import shape_list |
22 | 23 |
|
23 | 24 | L2 = tf.keras.regularizers.l2(1e-6) |
24 | 25 |
|
@@ -180,14 +181,15 @@ def __init__(self, |
180 | 181 |
|
181 | 182 | def call(self, inputs, training=False, **kwargs): |
182 | 183 | outputs = self.ln(inputs, training=training) |
183 | | - outputs = tf.expand_dims(outputs, axis=2) |
| 184 | + B, T, E = shape_list(outputs) |
| 185 | + outputs = tf.reshape(outputs, [B, T, 1, E]) |
184 | 186 | outputs = self.pw_conv_1(outputs, training=training) |
185 | 187 | outputs = self.glu(outputs) |
186 | 188 | outputs = self.dw_conv(outputs, training=training) |
187 | 189 | outputs = self.bn(outputs, training=training) |
188 | 190 | outputs = self.swish(outputs) |
189 | 191 | outputs = self.pw_conv_2(outputs, training=training) |
190 | | - outputs = tf.squeeze(outputs, axis=2) |
| 192 | + outputs = tf.reshape(outputs, [B, T, E]) |
191 | 193 | outputs = self.do(outputs, training=training) |
192 | 194 | outputs = self.res_add([inputs, outputs]) |
193 | 195 | return outputs |
|