Skip to content

Commit 4c9255c

Browse files
committed
🚀 add post joint linear
1 parent 53497a5 commit 4c9255c

File tree

7 files changed

+26
-1
lines changed

7 files changed

+26
-1
lines changed

tensorflow_asr/models/conformer.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -391,6 +391,7 @@ def __init__(self,
391391
joint_dim: int = 1024,
392392
joint_activation: str = "tanh",
393393
prejoint_linear: bool = True,
394+
postjoint_linear: bool = False,
394395
joint_mode: str = "add",
395396
joint_trainable: bool = True,
396397
kernel_regularizer=L2,
@@ -428,6 +429,7 @@ def __init__(self,
428429
joint_dim=joint_dim,
429430
joint_activation=joint_activation,
430431
prejoint_linear=prejoint_linear,
432+
postjoint_linear=postjoint_linear,
431433
joint_mode=joint_mode,
432434
joint_trainable=joint_trainable,
433435
kernel_regularizer=kernel_regularizer,

tensorflow_asr/models/contextnet.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -210,6 +210,7 @@ def __init__(self,
210210
joint_dim: int = 1024,
211211
joint_activation: str = "tanh",
212212
prejoint_linear: bool = True,
213+
postjoint_linear: bool = False,
213214
joint_mode: str = "add",
214215
joint_trainable: bool = True,
215216
kernel_regularizer=L2,
@@ -238,6 +239,7 @@ def __init__(self,
238239
joint_dim=joint_dim,
239240
joint_activation=joint_activation,
240241
prejoint_linear=prejoint_linear,
242+
postjoint_linear=postjoint_linear,
241243
joint_mode=joint_mode,
242244
joint_trainable=joint_trainable,
243245
kernel_regularizer=kernel_regularizer,

tensorflow_asr/models/keras/conformer.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@ def __init__(self,
4343
joint_dim: int = 1024,
4444
joint_activation: str = "tanh",
4545
prejoint_linear: bool = True,
46+
postjoint_linear: bool = False,
4647
joint_mode: str = "add",
4748
joint_trainable: bool = True,
4849
kernel_regularizer=L2,
@@ -80,6 +81,7 @@ def __init__(self,
8081
joint_dim=joint_dim,
8182
joint_activation=joint_activation,
8283
prejoint_linear=prejoint_linear,
84+
postjoint_linear=postjoint_linear,
8385
joint_mode=joint_mode,
8486
joint_trainable=joint_trainable,
8587
kernel_regularizer=kernel_regularizer,

tensorflow_asr/models/keras/contextnet.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@ def __init__(self,
3838
joint_dim: int = 1024,
3939
joint_activation: str = "tanh",
4040
prejoint_linear: bool = True,
41+
postjoint_linear: bool = False,
4142
joint_mode: str = "add",
4243
joint_trainable: bool = True,
4344
kernel_regularizer=L2,
@@ -66,6 +67,7 @@ def __init__(self,
6667
joint_dim=joint_dim,
6768
joint_activation=joint_activation,
6869
prejoint_linear=prejoint_linear,
70+
postjoint_linear=postjoint_linear,
6971
joint_mode=joint_mode,
7072
joint_trainable=joint_trainable,
7173
kernel_regularizer=kernel_regularizer,

tensorflow_asr/models/keras/streaming_transducer.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@ def __init__(self,
4040
joint_dim: int = 640,
4141
joint_activation: str = "tanh",
4242
prejoint_linear: bool = True,
43+
postjoint_linear: bool = False,
4344
joint_mode: str = "add",
4445
joint_trainable: bool = True,
4546
kernel_regularizer = None,
@@ -71,6 +72,7 @@ def __init__(self,
7172
joint_dim=joint_dim,
7273
joint_activation=joint_activation,
7374
prejoint_linear=prejoint_linear,
75+
postjoint_linear=postjoint_linear,
7476
joint_mode=joint_mode,
7577
joint_trainable=joint_trainable,
7678
kernel_regularizer=kernel_regularizer,

tensorflow_asr/models/streaming_transducer.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -195,6 +195,7 @@ def __init__(self,
195195
joint_dim: int = 640,
196196
joint_activation: str = "tanh",
197197
prejoint_linear: bool = True,
198+
postjoint_linear: bool = False,
198199
joint_mode: str = "add",
199200
joint_trainable: bool = True,
200201
kernel_regularizer = None,
@@ -226,6 +227,7 @@ def __init__(self,
226227
joint_dim=joint_dim,
227228
joint_activation=joint_activation,
228229
prejoint_linear=prejoint_linear,
230+
postjoint_linear=postjoint_linear,
229231
joint_mode=joint_mode,
230232
joint_trainable=joint_trainable,
231233
kernel_regularizer=kernel_regularizer,

tensorflow_asr/models/transducer.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -165,6 +165,7 @@ def __init__(self,
165165
joint_dim: int = 1024,
166166
activation: str = "tanh",
167167
prejoint_linear: bool = True,
168+
postjoint_linear: bool = False,
168169
joint_mode: str = "add",
169170
kernel_regularizer=None,
170171
bias_regularizer=None,
@@ -183,6 +184,7 @@ def __init__(self,
183184
raise ValueError("activation must be either 'linear', 'relu' or 'tanh'")
184185

185186
self.prejoint_linear = prejoint_linear
187+
self.postjoint_linear = postjoint_linear
186188

187189
if self.prejoint_linear:
188190
self.ffn_enc = tf.keras.layers.Dense(
@@ -205,6 +207,13 @@ def __init__(self,
205207
else:
206208
raise ValueError("joint_mode must be either 'add' or 'concat'")
207209

210+
if self.postjoint_linear:
211+
self.ffn = tf.keras.layers.Dense(
212+
joint_dim, name=f"{name}_ffn",
213+
kernel_regularizer=kernel_regularizer,
214+
bias_regularizer=bias_regularizer
215+
)
216+
208217
self.ffn_out = tf.keras.layers.Dense(
209218
vocabulary_size, name=f"{name}_vocab",
210219
kernel_regularizer=kernel_regularizer,
@@ -221,6 +230,8 @@ def call(self, inputs, training=False, **kwargs):
221230
enc_out = self.enc_reshape(enc_out, repeats=tf.shape(pred_out)[1])
222231
pred_out = self.pred_reshape(pred_out, repeats=tf.shape(enc_out)[1])
223232
outputs = self.joint([enc_out, pred_out], training=training)
233+
if self.postjoint_linear:
234+
outputs = self.ffn(outputs, training=training)
224235
outputs = self.activation(outputs, training=training) # => [B, T, U, V]
225236
outputs = self.ffn_out(outputs, training=training)
226237
return outputs
@@ -231,7 +242,7 @@ def get_config(self):
231242
conf.update(self.ffn_out.get_config())
232243
conf.update(self.activation.get_config())
233244
conf.update(self.joint.get_config())
234-
conf.update({"prejoint_linear": self.prejoint_linear})
245+
conf.update({"prejoint_linear": self.prejoint_linear, "postjoint_linear": self.postjoint_linear})
235246
return conf
236247

237248

@@ -253,6 +264,7 @@ def __init__(self,
253264
joint_dim: int = 1024,
254265
joint_activation: str = "tanh",
255266
prejoint_linear: bool = True,
267+
postjoint_linear: bool = False,
256268
joint_mode: str = "add",
257269
joint_trainable: bool = True,
258270
kernel_regularizer=None,
@@ -281,6 +293,7 @@ def __init__(self,
281293
joint_dim=joint_dim,
282294
activation=joint_activation,
283295
prejoint_linear=prejoint_linear,
296+
postjoint_linear=postjoint_linear,
284297
joint_mode=joint_mode,
285298
kernel_regularizer=kernel_regularizer,
286299
bias_regularizer=bias_regularizer,

0 commit comments

Comments
 (0)