@@ -232,24 +232,24 @@ def __init__(
232232 bvs = random .normal (subkeys [13 ], (1 , learnable_query_dim )) * sigma
233233 Wouts = random .normal (subkeys [14 ], (learnable_query_dim , learnable_query_dim )) * sigma
234234 bouts = random .normal (subkeys [15 ], (1 , learnable_query_dim )) * sigma
235- Wlnattn_mu = jnp .zeros ((1 , learnable_query_dim ))
236- Wlnattn_scale = jnp .ones ((1 , learnable_query_dim ))
235+ Wlnattn_mu = jnp .zeros ((1 , learnable_query_dim )) ## LN parameter (applied to output of attention)
236+ Wlnattn_scale = jnp .ones ((1 , learnable_query_dim )) ## LN parameter (applied to output of attention)
237237 self_attn_params = (Wqs , bqs , Wks , bks , Wvs , bvs , Wouts , bouts , Wlnattn_mu , Wlnattn_scale )
238238 learnable_query = jnp .zeros ((batch_size , 1 , learnable_query_dim )) # (B, T, D)
239239 self .mask = np .zeros ((batch_size , target_seq_length , source_seq_length )).astype (bool ) ## mask tensor
240240 ## MLP parameters
241241 Whid1 = random .normal (subkeys [16 ], (learnable_query_dim , learnable_query_dim )) * sigma
242242 bhid1 = random .normal (subkeys [17 ], (1 , learnable_query_dim )) * sigma
243- Wln_mu1 = jnp .zeros ((1 , learnable_query_dim ))
244- Wln_scale1 = jnp .ones ((1 , learnable_query_dim ))
243+ Wln_mu1 = jnp .zeros ((1 , learnable_query_dim )) ## LN parameter
244+ Wln_scale1 = jnp .ones ((1 , learnable_query_dim )) ## LN parameter
245245 Whid2 = random .normal (subkeys [18 ], (learnable_query_dim , learnable_query_dim * 4 )) * sigma
246246 bhid2 = random .normal (subkeys [19 ], (1 , learnable_query_dim * 4 )) * sigma
247- Wln_mu2 = jnp .zeros ((1 , learnable_query_dim ))
248- Wln_scale2 = jnp .ones ((1 , learnable_query_dim ))
247+ Wln_mu2 = jnp .zeros ((1 , learnable_query_dim )) ## LN parameter
248+ Wln_scale2 = jnp .ones ((1 , learnable_query_dim )) ## LN parameter
249249 Whid3 = random .normal (subkeys [20 ], (learnable_query_dim * 4 , learnable_query_dim )) * sigma
250250 bhid3 = random .normal (subkeys [21 ], (1 , learnable_query_dim )) * sigma
251- Wln_mu3 = jnp .zeros ((1 , learnable_query_dim * 4 ))
252- Wln_scale3 = jnp .ones ((1 , learnable_query_dim * 4 ))
251+ Wln_mu3 = jnp .zeros ((1 , learnable_query_dim * 4 )) ## LN parameter
252+ Wln_scale3 = jnp .ones ((1 , learnable_query_dim * 4 )) ## LN parameter
253253 Wy = random .normal (subkeys [22 ], (learnable_query_dim , out_dim )) * sigma
254254 by = random .normal (subkeys [23 ], (1 , out_dim )) * sigma
255255 mlp_params = (Whid1 , bhid1 , Wln_mu1 , Wln_scale1 , Whid2 , bhid2 , Wln_mu2 , Wln_scale2 , Whid3 , bhid3 , Wln_mu3 , Wln_scale3 , Wy , by )
0 commit comments