diff --git a/modules.py b/modules.py index 4222d0a..ff8760b 100644 --- a/modules.py +++ b/modules.py @@ -247,7 +247,10 @@ def multihead_attention(queries, # Restore shape outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2 ) # (N, T_q, C) - + + # Linear projections + outputs = tf.layers.dense(outputs, num_units, activation=tf.nn.relu) # (N, T_q, C) + # Residual connection outputs += queries