@@ -73,8 +73,8 @@ def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array
     attention = attention.transpose([0, 2, 1, 3]).reshape((B, T, -1)) # (B, T, H, E) => (B, T, D)
     return attention @ Wout + bout # (B, T, Dq)

-@bind(jax.jit, static_argnums=[3, 4, 5, 6])
-def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_softmax=True):
+@bind(jax.jit, static_argnums=[3, 4, 5, 6, 7])
+def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=True, use_softmax=True):
     """
     Runs full nonlinear attentive probe on input encodings (typically embedding vectors produced by some other model).

@@ -101,8 +101,11 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
     learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout,\
     Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts, Wlnattn_mu,\
     Wlnattn_scale, Whid1, bhid1, Wln_mu1, Wln_scale1, Whid2,\
-    bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3, Wy, by = params
+    bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3,\
+    Wy, by, ln_in_mu, ln_in_scale = params
     cross_attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
+    if use_LN_input:
+        learnable_query = layer_normalize(learnable_query, ln_in_mu, ln_in_scale)
     features = cross_attention(cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
     # Perform a single self-attention block here
     # Self-Attention
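The new `use_LN_input` path layer-normalizes the learnable query before it enters cross-attention, using the freshly added `ln_in_mu` / `ln_in_scale` parameters. For context, a minimal sketch of a layer-norm helper consistent with the `layer_normalize(x, mu, scale)` call above is shown below; the repository's actual helper may differ in its exact form, and `eps` is an assumed default.

```python
import jax.numpy as jnp

def layer_normalize(x, mu, scale, eps=1e-6):
    # Normalize over the last (feature) axis, then apply the learnable
    # shift (`mu`) and gain (`scale`), matching the (x, mu, scale) call signature.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    x_hat = (x - mean) / jnp.sqrt(var + eps)
    return x_hat * scale + mu
```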
@@ -200,7 +203,7 @@ class AttentiveProbe(Probe):
200203 """
201204 def __init__ (
202205 self , dkey , source_seq_length , input_dim , out_dim , num_heads = 8 , attn_dim = 64 ,
203- target_seq_length = 1 , learnable_query_dim = 32 , batch_size = 1 , hid_dim = 32 , use_LN = True , use_softmax = True , ** kwargs
206+ target_seq_length = 1 , learnable_query_dim = 32 , batch_size = 1 , hid_dim = 32 , use_LN = True , use_LN_input = True , use_softmax = True , ** kwargs
204207 ):
205208 super ().__init__ (dkey , batch_size , ** kwargs )
206209 assert attn_dim % num_heads == 0 , f"`attn_dim` must be divisible by `num_heads`. Got { attn_dim } and { num_heads } ."
@@ -212,6 +215,7 @@ def __init__(
         self.out_dim = out_dim
         self.use_softmax = use_softmax
         self.use_LN = use_LN
+        self.use_LN_input = use_LN_input

         sigma = 0.05
         ## cross-attention parameters
@@ -254,7 +258,11 @@ def __init__(
         Wy = random.normal(subkeys[22], (learnable_query_dim, out_dim)) * sigma
         by = random.normal(subkeys[23], (1, out_dim)) * sigma
         mlp_params = (Whid1, bhid1, Wln_mu1, Wln_scale1, Whid2, bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3, Wy, by)
-        self.probe_params = (learnable_query, *cross_attn_params, *self_attn_params, *mlp_params)
+        # Finally, define the LN parameters applied to the attention input
+        ln_in_mu = jnp.zeros((1, learnable_query_dim)) ## LN parameter
+        ln_in_scale = jnp.ones((1, learnable_query_dim)) ## LN parameter
+        ln_in_params = (ln_in_mu, ln_in_scale)
+        self.probe_params = (learnable_query, *cross_attn_params, *self_attn_params, *mlp_params, *ln_in_params)

         ## set up gradient calculator
         self.grad_fx = jax.value_and_grad(eval_attention_probe, argnums=0, has_aux=True)
@@ -294,8 +302,9 @@ def update(self, embedding_sequence, labels, dkey=None):
294302 """
295303 # TODO: put in dkey to facilitate dropout
296304 ## compute partial derivatives / adjustments to probe parameters
305+ # NOTE: Viet: Change back to 0.0 for now for the code to run
297306 outputs , grads = self .grad_fx (
298- self .probe_params , embedding_sequence , labels , self .mask , self .num_heads , dropout = 0.5 , use_LN = self .use_LN ,
307+ self .probe_params , embedding_sequence , labels , self .mask , self .num_heads , dropout = 0.0 , use_LN = self .use_LN ,
299308 use_softmax = self .use_softmax
300309 )
301310 loss , predictions = outputs
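For reference, a hypothetical construction of the probe with the new flag enabled might look like the sketch below. The keyword arguments follow the `__init__` signature shown in this diff, but the dimensions are illustrative and the import path for `AttentiveProbe` is omitted.

```python
import jax

# Illustrative values only; import AttentiveProbe from wherever it lives in the repo.
probe = AttentiveProbe(
    dkey=jax.random.PRNGKey(0),
    source_seq_length=16, input_dim=128, out_dim=10,
    num_heads=8, attn_dim=64, learnable_query_dim=32,
    batch_size=4, hid_dim=32,
    use_LN=True, use_LN_input=True,  # use_LN_input is the flag added in this change
    use_softmax=True,
)
```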