@@ -102,10 +102,11 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
     Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts, Wlnattn_mu,\
     Wlnattn_scale, Whid1, bhid1, Wln_mu1, Wln_scale1, Whid2,\
     bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3,\
-    Wy, by, ln_in_mu, ln_in_scale = params
+    Wy, by, ln_in_mu, ln_in_scale, ln_in_mu2, ln_in_scale2 = params
     cross_attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
     if use_LN_input:
         learnable_query = layer_normalize(learnable_query, ln_in_mu, ln_in_scale)
+        encodings = layer_normalize(encodings, ln_in_mu2, ln_in_scale2)
     features = cross_attention(cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
     # Perform a single self-attention block here
     # Self-Attention
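For reference, `layer_normalize` as called above appears to take the tensor to normalize plus a learnable shift (`mu`) and scale pair; the repository's own helper is defined elsewhere, so the following is only a minimal sketch of that assumed behavior, not the actual implementation:

import jax.numpy as jnp

def layer_normalize(x, mu, scale, eps=1e-6):
    # Assumed semantics: standardize each row over its feature axis, then
    # apply the learnable scale and shift. The eps value and exact formula
    # are guesses, not taken from the repository.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return scale * ((x - mean) / jnp.sqrt(var + eps)) + mu

Under this reading, the new line normalizes the raw `encodings` with their own `(1, input_dim)` parameters before cross-attention, mirroring what was already done to `learnable_query`.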
@@ -261,7 +262,9 @@ def __init__(
         # Finally, define ln for the input to the attention
         ln_in_mu = jnp.zeros((1, learnable_query_dim)) ## LN parameter
         ln_in_scale = jnp.ones((1, learnable_query_dim)) ## LN parameter
-        ln_in_params = (ln_in_mu, ln_in_scale)
+        ln_in_mu2 = jnp.zeros((1, input_dim)) ## LN parameter
+        ln_in_scale2 = jnp.ones((1, input_dim)) ## LN parameter
+        ln_in_params = (ln_in_mu, ln_in_scale, ln_in_mu2, ln_in_scale2)
         self.probe_params = (learnable_query, *cross_attn_params, *self_attn_params, *mlp_params, *ln_in_params)

         ## set up gradient calculator
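The two new parameters are shaped `(1, input_dim)` rather than `(1, learnable_query_dim)` because they normalize the encodings, which carry one `input_dim`-wide row per token; the singleton leading axis broadcasts across tokens. A toy shape check, reusing the `layer_normalize` sketch above and purely hypothetical dimensions:

import jax.numpy as jnp

input_dim = 32                          # illustrative width, not from the repo
encodings = jnp.ones((10, input_dim))   # 10 token encodings

ln_in_mu2 = jnp.zeros((1, input_dim))
ln_in_scale2 = jnp.ones((1, input_dim))

# Each of the 10 rows is normalized with the same learnable shift/scale.
normalized = layer_normalize(encodings, ln_in_mu2, ln_in_scale2)
assert normalized.shape == (10, input_dim)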