@@ -32,10 +32,15 @@ def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array

     Args:
         params (tuple): tuple of parameters
+
         x1 (jax.Array): query sequence. Shape: (B, T, Dq)
+
         x2 (jax.Array): key-value sequence. Shape: (B, S, Dkv)
+
         mask (jax.Array): mask tensor. Shape: (B, T, S)
+
         n_heads (int, optional): number of attention heads. Defaults to 8.
+
         dropout_rate (float, optional): dropout rate. Defaults to 0.0.

     Returns:
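For reference, a minimal sketch of inputs matching the documented shapes; the sizes and the all-ones placeholder mask below are illustrative assumptions, not values taken from this module:

import jax.numpy as jnp

B, T, S, Dq, Dkv = 2, 4, 16, 64, 64   # batch, query length, key/value length, dims (assumed sizes)
x1 = jnp.zeros((B, T, Dq))            # query sequence
x2 = jnp.zeros((B, S, Dkv))           # key-value sequence
mask = jnp.ones((B, T, S))            # placeholder mask of the documented shape (the attend/block convention is not shown in this hunk)
# cross_attention(params, x1, x2, mask, n_heads=8, dropout_rate=0.0) then yields
# attention outputs aligned with the query sequence x1.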
@@ -70,6 +75,27 @@ def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array

 @bind(jax.jit, static_argnums=[3, 4, 5, 6])
 def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_softmax=True):
+    """
+    Runs the full nonlinear attentive probe on input encodings (typically embedding vectors produced by some other model).
+
+    Args:
+        params: parameter tuple/list of the probe
+
+        encodings: input encoding vectors/data
+
+        mask: optional mask to be applied to the internal cross-attention
+
+        n_heads: number of attention heads
+
+        dropout: if > 0, dropout is applied internally to the cross-attention
+
+        use_LN: whether to apply layer normalization
+
+        use_softmax: whether to apply a softmax to the probe's output (useful for classification)
+
+    Returns:
+        output scores/probabilities, cross-attention (hidden) features
+    """
     # encoded_image_feature: (B, hw, dim)
     # (learnable_query, *_params) = params
     learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout, Whid, bhid, Wln_mu, Wln_scale, Wy, by = params
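A minimal call sketch for this function, assuming `params` already holds the fifteen entries unpacked above and `encodings` has shape (B, hw, dim); the concrete sizes are not taken from this diff:

outs, feats = run_attention_probe(
    params, encodings, mask, 8, 0.0, use_LN=False, use_softmax=True
)
# outs:  the probe's output scores/probabilities
# feats: the cross-attention (hidden) features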
@@ -87,6 +113,30 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =

 @bind(jax.jit, static_argnums=[4, 5, 6, 7])
 def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_softmax=True):
+    """
+    Runs and evaluates the nonlinear attentive probe given a paired set of encoding vectors and externally assigned
+    labels/regression targets.
+
+    Args:
+        params: parameter tuple/list of the probe
+
+        encodings: input encoding vectors/data
+
+        labels: target output values (e.g., one-hot labels, regression target vectors)
+
+        mask: optional mask to be applied to the internal cross-attention
+
+        n_heads: number of attention heads
+
+        dropout: if > 0, dropout is applied internally to the cross-attention
+
+        use_LN: whether to apply layer normalization
+
+        use_softmax: whether to apply a softmax to the probe's output (useful for classification)
+
+    Returns:
+        current loss value, output scores/probabilities
+    """
     # encodings: (B, hw, dim)
     outs, _ = run_attention_probe(params, encodings, mask, n_heads, dropout, use_LN, use_softmax)
     if use_softmax: ## Multinoulli log likelihood for 1-of-K predictions
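The Multinoulli log likelihood mentioned in the comment above is the usual categorical cross-entropy; a rough sketch of that branch, assuming jax.numpy is imported as jnp, `outs` holds softmax probabilities, and `labels` are one-hot vectors (the epsilon and mean reduction are illustrative, not necessarily what this function uses):

loss = -jnp.mean(jnp.sum(labels * jnp.log(outs + 1e-7), axis=-1))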
@@ -97,6 +147,10 @@ def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout:

 class AttentiveProbe(Probe):
     """
+    This implements a nonlinear attentive probe, which is useful for evaluating the quality of
+    encodings/embeddings in light of some supervisory downstream data (e.g., label one-hot
+    encodings or real-valued vector regression targets).
+
     Args:
         dkey: init seed key

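As an aside on the supervisory targets mentioned in the class docstring, one-hot label encodings can be produced with jax.nn.one_hot; the class count used here is only an assumption:

import jax.numpy as jnp
from jax import nn

int_labels = jnp.array([0, 3, 1])                         # integer class indices for a batch of 3
one_hot_targets = nn.one_hot(int_labels, num_classes=10)  # shape (3, 10)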
@@ -167,13 +221,34 @@ def __init__(
         self.eta = 0.001

     def process(self, embedding_sequence):
+        """
+        Runs the probe's inference scheme given an input batch of sequences of encodings/embeddings.
+
+        Args:
+            embedding_sequence: a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
+
+        Returns:
+            probe output scores/probability values
+        """
         outs, feats = run_attention_probe(
             self.probe_params, embedding_sequence, self.mask, self.num_heads, 0.0, use_LN=self.use_LN,
             use_softmax=self.use_softmax
         )
         return outs

     def update(self, embedding_sequence, labels):
+        """
+        Runs and updates this probe given an input batch of sequences of encodings/embeddings and their externally
+        assigned labels/target vector values.
+
+        Args:
+            embedding_sequence: a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
+
+            labels: target values that map to the embedding sequence; shape (B, target_value_dim)
+
+        Returns:
+            probe output scores/probability values
+        """
         ## compute partial derivatives / adjustments to probe parameters
         outputs, grads = self.grad_fx(
             self.probe_params, embedding_sequence, labels, self.mask, self.num_heads, dropout=0., use_LN=self.use_LN,
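Putting the two methods together, a hedged end-to-end sketch; the constructor arguments besides dkey are elided because they are not shown in this diff:

import jax

dkey = jax.random.PRNGKey(1234)
probe = AttentiveProbe(dkey, ...)                  # remaining constructor arguments elided

scores = probe.process(embedding_sequence)         # inference only over a (B, T, embed_dim) batch
scores = probe.update(embedding_sequence, labels)  # computes gradients, adjusts probe_params, returns scores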