@@ -21,8 +21,8 @@ def masked_fill(x: jax.Array, mask: jax.Array, value=0) -> jax.Array:
     """
     return jnp.where(mask, jnp.broadcast_to(value, x.shape), x)

-@bind(jax.jit, static_argnums=[4, 5])
-def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array, n_heads: int = 8, dropout_rate: float = 0.0) -> jax.Array:
+@bind(jax.jit, static_argnums=[5, 6])
+def cross_attention(dkey, params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array, n_heads: int = 8, dropout_rate: float = 0.0) -> jax.Array:
     """
     Run cross-attention function given a list of parameters and two sequences (x1 and x2).
     The function takes in a query sequence x1 and a key-value sequence x2, and returns an output of the same shape as x1.
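Side note on the shifted static indices: `static_argnums` counts every positional argument, and the traced `dkey` now occupies slot 0, so the static positions of `n_heads` and `dropout_rate` move from 4, 5 to 5, 6. The repo's `bind` helper is not shown in this diff; a toy illustration of the same effect with plain `jax.jit` (an assumption, not repo code):

    import jax
    import jax.numpy as jnp

    # Toy function mirroring the argument layout: a traced key first, then traced
    # arrays, then Python-level configuration that must be marked static.
    def toy_attention(dkey, params, x, n_heads: int):
        del dkey  # the key is unused in this toy; it only occupies slot 0
        return jnp.dot(x, params) * n_heads

    # n_heads sits at position 3 here; before prepending dkey it would have been 2.
    toy_attention_jit = jax.jit(toy_attention, static_argnums=(3,))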
@@ -31,6 +31,8 @@ def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array
     H is the number of attention heads.

     Args:
+        dkey: JAX key to trigger any internal noise (drop-out)
+
         params (tuple): tuple of parameters

         x1 (jax.Array): query sequence. Shape: (B, T, Dq)
@@ -68,17 +70,22 @@ def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array
     score = jax.nn.softmax(score, axis=-1) # (B, H, T, S)
     score = score.astype(q.dtype) # (B, H, T, S)
     if dropout_rate > 0.:
-        score = drop_out(input=score, rate=dropout_rate) ## NOTE: normally you apply dropout here
+        score = drop_out(dkey, input=score, rate=dropout_rate) ## NOTE: normally you apply dropout here
     attention = jnp.einsum("BHTS,BHSE->BHTE", score, v) # (B, H, T, E)
     attention = attention.transpose([0, 2, 1, 3]).reshape((B, T, -1)) # (B, T, H, E) => (B, T, D)
     return attention @ Wout + bout # (B, T, Dq)

-@bind(jax.jit, static_argnums=[3, 4, 5, 6, 7])
-def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False, use_softmax=True):
+@bind(jax.jit, static_argnums=[4, 5, 6, 7, 8])
+def run_attention_probe(
+        dkey, params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False,
+        use_softmax=True
+):
     """
     Runs full nonlinear attentive probe on input encodings (typically embedding vectors produced by some other model).

     Args:
+        dkey: JAX key for any internal noise to be applied
+
         params: parameters tuple/list of probe

         encodings: input encoding vectors/data
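The `drop_out` helper that the new call assumes (key first, then the tensor and rate) is not part of this diff; a minimal sketch of a key-consuming inverted-dropout function consistent with that call signature could look like the following (the exact name and default rate are assumptions):

    import jax
    import jax.numpy as jnp

    def drop_out(dkey, input, rate: float = 0.5):
        """Inverted dropout: zero entries with probability `rate`, rescale the rest."""
        keep_prob = 1.0 - rate
        keep_mask = jax.random.bernoulli(dkey, p=keep_prob, shape=input.shape)
        return jnp.where(keep_mask, input / keep_prob, jnp.zeros_like(input))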
@@ -91,6 +98,8 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =

         use_LN: use layer normalization?

+        use_LN_input: use layer normalization on input encodings?
+
         use_softmax: should softmax be applied to output of attention probe? (useful for classification)

     Returns:
@@ -107,7 +116,7 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
     if use_LN_input:
         learnable_query = layer_normalize(learnable_query, ln_in_mu, ln_in_scale)
         encodings = layer_normalize(encodings, ln_in_mu2, ln_in_scale2)
-    features = cross_attention(cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
+    features = cross_attention(dkey, cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
     # Perform a single self-attention block here
     # Self-Attention
     self_attn_params = (Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts)
@@ -138,13 +147,15 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
         outs = jax.nn.softmax(outs)
     return outs, features

-@bind(jax.jit, static_argnums=[4, 5, 6, 7, 8])
-def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False, use_softmax=True):
+@bind(jax.jit, static_argnums=[5, 6, 7, 8, 9])
+def eval_attention_probe(dkey, params, encodings, labels, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False, use_softmax=True):
     """
     Runs and evaluates the nonlinear attentive probe given a paired set of encoding vectors and externally assigned
     labels/regression targets.

     Args:
+        dkey: JAX key to trigger any internal noise (as in drop-out)
+
         params: parameters tuple/list of probe

         encodings: input encoding vectors/data
@@ -165,7 +176,7 @@ def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout:
         current loss value, output scores/probabilities
     """
     # encodings: (B, hw, dim)
-    outs, _ = run_attention_probe(params, encodings, mask, n_heads, dropout, use_LN, use_LN_input, use_softmax)
+    outs, _ = run_attention_probe(dkey, params, encodings, mask, n_heads, dropout, use_LN, use_LN_input, use_softmax)
     if use_softmax: ## Multinoulli log likelihood for 1-of-K predictions
         L = -jnp.mean(jnp.sum(jnp.log(outs.clip(min=1e-5)) * labels, axis=1, keepdims=True))
     else: ## MSE for real-valued outputs
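Because `dkey` now occupies position 0, any gradient transform built on `eval_attention_probe` has to point at `params` by its new index. The class's `self.grad_fx` is not shown in these hunks; one construction consistent with how `update` unpacks its result (a sketch, not necessarily the repo's actual code) would be:

    import jax

    # Differentiate w.r.t. `params` (now argument 1, since `dkey` is argument 0).
    # has_aux=True matches eval_attention_probe returning (loss, outs), so
    # grad_fx(...) -> ((loss, outs), grads).
    grad_fx = jax.value_and_grad(eval_attention_probe, argnums=1, has_aux=True)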
@@ -219,6 +230,7 @@ def __init__(
         self.use_softmax = use_softmax
         self.use_LN = use_LN
         self.use_LN_input = use_LN_input
+        self.dropout = 0.5

         sigma = 0.05
         ## cross-attention parameters
@@ -275,42 +287,25 @@ def __init__(
         self.optim_params = adam.adam_init(self.probe_params)
         self.eta = 0.0002 #0.001

-    def process(self, embedding_sequence):
-        """
-        Runs the probe's inference scheme given an input batch of sequences of encodings/embeddings.
-
-        Args:
-            embedding_sequence: a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
-
-        Returns:
-            probe output scores/probability values
-        """
-        #print(embedding_sequence.shape)
+    def process(self, embeddings, dkey=None):
+        noise_key = None
+        if dkey is not None:
+            dkey, *subkeys = random.split(dkey, 2)
+            noise_key = subkeys[0]
         outs, feats = run_attention_probe(
-            self.probe_params, embedding_sequence, self.dev_mask, self.num_heads, 0.0, use_LN=self.use_LN,
-            use_LN_input=self.use_LN_input, use_softmax=self.use_softmax
+            noise_key, self.probe_params, embeddings, self.dev_mask, self.num_heads, 0.0,
+            use_LN=self.use_LN, use_LN_input=self.use_LN_input, use_softmax=self.use_softmax
         )
         return outs

-    def update(self, embedding_sequence, labels, dkey=None):
-        """
-        Runs and updates this probe given an input batch of sequences of encodings/embeddings and their externally
-        assigned labels/target vector values.
-
-        Args:
-            embedding_sequence: a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
-
-            labels: target values that map to embedding sequence; shape (B, target_value_dim)
-
-        Returns:
-            probe output scores/probability values
-        """
-        # TODO: put in dkey to facilitate dropout
-        ## compute partial derivatives / adjustments to probe parameters
-        # NOTE: Viet: Change back to 0.0 for now for the code to run
+    def update(self, embeddings, labels, dkey=None):
+        noise_key = None
+        if dkey is not None:
+            dkey, *subkeys = random.split(dkey, 2)
+            noise_key = subkeys[0]
         outputs, grads = self.grad_fx(
-            self.probe_params, embedding_sequence, labels, self.mask, self.num_heads, dropout=0.0, use_LN=self.use_LN,
-            use_LN_input=self.use_LN_input, use_softmax=self.use_softmax
+            noise_key, self.probe_params, embeddings, labels, self.mask, self.num_heads, dropout=self.dropout,
+            use_LN=self.use_LN, use_LN_input=self.use_LN_input, use_softmax=self.use_softmax
         )
         loss, predictions = outputs
         ## adjust parameters of probe
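Taken together, the new key threading looks like the following at the call site. A usage sketch, assuming the probe object has already been constructed as `probe`, and with `num_steps` and `batches` as placeholders for the caller's own training loop:

    from jax import random

    key = random.PRNGKey(1234)
    for _ in range(num_steps):
        batch_embeddings, batch_labels = next(batches)
        key, dkey = random.split(key)  # fresh subkey per step => new dropout mask each update
        probe.update(batch_embeddings, batch_labels, dkey=dkey)

    # Inference path: process() passes a dropout rate of 0.0, so omitting dkey keeps it deterministic.
    scores = probe.process(batch_embeddings)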