@@ -74,7 +74,7 @@ def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array
     return attention @ Wout + bout # (B, T, Dq)
 
 @bind(jax.jit, static_argnums=[3, 4, 5, 6, 7])
-def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=True, use_softmax=True):
+def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False, use_softmax=True):
     """
     Runs full nonlinear attentive probe on input encodings (typically embedding vectors produced by some other model).
 
@@ -138,8 +138,8 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
         outs = jax.nn.softmax(outs)
     return outs, features
 
-@bind(jax.jit, static_argnums=[4, 5, 6, 7])
-def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_softmax=True):
+@bind(jax.jit, static_argnums=[4, 5, 6, 7, 8])
+def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False, use_softmax=True):
     """
     Runs and evaluates the nonlinear attentive probe given a paired set of encoding vectors and externally assigned
     labels/regression targets.
@@ -165,7 +165,7 @@ def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout:
         current loss value, output scores/probabilities
     """
     # encodings: (B, hw, dim)
-    outs, _ = run_attention_probe(params, encodings, mask, n_heads, dropout, use_LN, use_softmax)
+    outs, _ = run_attention_probe(params, encodings, mask, n_heads, dropout, use_LN, use_LN_input, use_softmax)
     if use_softmax: ## Multinoulli log likelihood for 1-of-K predictions
         L = -jnp.mean(jnp.sum(jnp.log(outs.clip(min=1e-5)) * labels, axis=1, keepdims=True))
     else: ## MSE for real-valued outputs
@@ -206,7 +206,7 @@ class AttentiveProbe(Probe):
206206 """
207207 def __init__ (
208208 self , dkey , source_seq_length , input_dim , out_dim , num_heads = 8 , attn_dim = 64 ,
209- target_seq_length = 1 , learnable_query_dim = 32 , batch_size = 1 , hid_dim = 32 , use_LN = True , use_LN_input = True , use_softmax = True , ** kwargs
209+ target_seq_length = 1 , learnable_query_dim = 32 , batch_size = 1 , hid_dim = 32 , use_LN = True , use_LN_input = False , use_softmax = True , ** kwargs
210210 ):
211211 super ().__init__ (dkey , batch_size , ** kwargs )
212212 assert attn_dim % num_heads == 0 , f"`attn_dim` must be divisible by `num_heads`. Got { attn_dim } and { num_heads } ."
@@ -288,7 +288,7 @@ def process(self, embedding_sequence):
         #print(embedding_sequence.shape)
         outs, feats = run_attention_probe(
             self.probe_params, embedding_sequence, self.dev_mask, self.num_heads, 0.0, use_LN=self.use_LN,
-            use_softmax=self.use_softmax
+            use_LN_input=self.use_LN_input, use_softmax=self.use_softmax
         )
         return outs
 
@@ -310,7 +310,7 @@ def update(self, embedding_sequence, labels, dkey=None):
         # NOTE: Viet: Change back to 0.0 for now for the code to run
         outputs, grads = self.grad_fx(
             self.probe_params, embedding_sequence, labels, self.mask, self.num_heads, dropout=0.0, use_LN=self.use_LN,
-            use_softmax=self.use_softmax
+            use_LN_input=self.use_LN_input, use_softmax=self.use_softmax
         )
         loss, predictions = outputs
         ## adjust parameters of probe
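
A minimal sketch of why the `static_argnums` list for `eval_attention_probe` grows from [4, 5, 6, 7] to [4, 5, 6, 7, 8]: the new `use_LN_input` flag is a Python boolean that gates control flow, so it must be marked static for `jax.jit` to resolve the branch at trace time. The toy function below is hypothetical (it is not the repository's probe; the real `bind` decorator and probe parameters are not shown) and only illustrates the pattern, assuming plain `jax.jit` via `functools.partial`.

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=[1, 2])  # boolean flags are compile-time constants
def toy_probe(x, use_LN_input: bool = False, use_softmax: bool = True):
    if use_LN_input:  # branch resolved at trace time because the argument is static
        x = (x - x.mean(axis=-1, keepdims=True)) / (x.std(axis=-1, keepdims=True) + 1e-6)
    out = x @ jnp.ones((x.shape[-1], 3))
    return jax.nn.softmax(out) if use_softmax else out

scores = toy_probe(jnp.ones((2, 4)), True, True)  # shape (2, 3), each row sums to 1

Because the flag sits between `use_LN` and `use_softmax` in the signature, every call site also has to pass it by keyword (as the hunks in `process` and `update` do) so the trailing `use_softmax` argument keeps its meaning.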