@@ -165,6 +165,7 @@ def __init__(self,
165165 joint_dim : int = 1024 ,
166166 activation : str = "tanh" ,
167167 prejoint_linear : bool = True ,
168+ postjoint_linear : bool = False ,
168169 joint_mode : str = "add" ,
169170 kernel_regularizer = None ,
170171 bias_regularizer = None ,
@@ -183,6 +184,7 @@ def __init__(self,
183184 raise ValueError ("activation must be either 'linear', 'relu' or 'tanh'" )
184185
185186 self .prejoint_linear = prejoint_linear
187+ self .postjoint_linear = postjoint_linear
186188
187189 if self .prejoint_linear :
188190 self .ffn_enc = tf .keras .layers .Dense (
@@ -205,6 +207,13 @@ def __init__(self,
205207 else :
206208 raise ValueError ("joint_mode must be either 'add' or 'concat'" )
207209
210+ if self .postjoint_linear :
211+ self .ffn = tf .keras .layers .Dense (
212+ joint_dim , name = f"{ name } _ffn" ,
213+ kernel_regularizer = kernel_regularizer ,
214+ bias_regularizer = bias_regularizer
215+ )
216+
208217 self .ffn_out = tf .keras .layers .Dense (
209218 vocabulary_size , name = f"{ name } _vocab" ,
210219 kernel_regularizer = kernel_regularizer ,
@@ -221,6 +230,8 @@ def call(self, inputs, training=False, **kwargs):
221230 enc_out = self .enc_reshape (enc_out , repeats = tf .shape (pred_out )[1 ])
222231 pred_out = self .pred_reshape (pred_out , repeats = tf .shape (enc_out )[1 ])
223232 outputs = self .joint ([enc_out , pred_out ], training = training )
233+ if self .postjoint_linear :
234+ outputs = self .ffn (outputs , training = training )
224235 outputs = self .activation (outputs , training = training ) # => [B, T, U, V]
225236 outputs = self .ffn_out (outputs , training = training )
226237 return outputs
@@ -231,7 +242,7 @@ def get_config(self):
231242 conf .update (self .ffn_out .get_config ())
232243 conf .update (self .activation .get_config ())
233244 conf .update (self .joint .get_config ())
234- conf .update ({"prejoint_linear" : self .prejoint_linear })
245 conf .update ({"prejoint_linear" : self .prejoint_linear , "postjoint_linear" : self .postjoint_linear })
235246 return conf
236247
237248
@@ -253,6 +264,7 @@ def __init__(self,
253264 joint_dim : int = 1024 ,
254265 joint_activation : str = "tanh" ,
255266 prejoint_linear : bool = True ,
267+ postjoint_linear : bool = False ,
256268 joint_mode : str = "add" ,
257269 joint_trainable : bool = True ,
258270 kernel_regularizer = None ,
@@ -281,6 +293,7 @@ def __init__(self,
281293 joint_dim = joint_dim ,
282294 activation = joint_activation ,
283295 prejoint_linear = prejoint_linear ,
296+ postjoint_linear = postjoint_linear ,
284297 joint_mode = joint_mode ,
285298 kernel_regularizer = kernel_regularizer ,
286299 bias_regularizer = bias_regularizer ,
0 commit comments