
Commit 247de74

Alexander Ororbia committed
cleaned up probes/docs for probes
1 parent 9d7acbb commit 247de74

File tree: 3 files changed, +132 -3 lines changed


ngclearn/utils/analysis/attentive_probe.py

Lines changed: 75 additions & 0 deletions
@@ -32,10 +32,15 @@ def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array
 
     Args:
         params (tuple): tuple of parameters
+
         x1 (jax.Array): query sequence. Shape: (B, T, Dq)
+
         x2 (jax.Array): key-value sequence. Shape: (B, S, Dkv)
+
         mask (jax.Array): mask tensor. Shape: (B, T, S)
+
         n_heads (int, optional): number of attention heads. Defaults to 8.
+
         dropout_rate (float, optional): dropout rate. Defaults to 0.0.
 
     Returns:
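To make the documented shapes concrete, here is a minimal, self-contained single-head cross-attention sketch in JAX; `toy_cross_attention` and the weight names are illustrative assumptions, not ngclearn's multi-head implementation. The query sequence x1 is (B, T, Dq), the key-value sequence x2 is (B, S, Dkv), and the mask is (B, T, S).

```python
import jax
import jax.numpy as jnp

def toy_cross_attention(x1, x2, Wq, Wk, Wv, mask=None):
    q = x1 @ Wq                                    # (B, T, d) queries from x1
    k = x2 @ Wk                                    # (B, S, d) keys from x2
    v = x2 @ Wv                                    # (B, S, d) values from x2
    scores = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(q.shape[-1])  # (B, T, S) attention logits
    if mask is not None:
        scores = jnp.where(mask, scores, -1e9)     # mask out disallowed key positions
    weights = jax.nn.softmax(scores, axis=-1)      # normalize over the S (key) axis
    return weights @ v                             # (B, T, d) attended features

key = jax.random.PRNGKey(0)
k1, k2, k3, k4, k5 = jax.random.split(key, 5)
B, T, S, Dq, Dkv, d = 2, 1, 4, 8, 8, 16
x1 = jax.random.normal(k1, (B, T, Dq))             # query sequence
x2 = jax.random.normal(k2, (B, S, Dkv))            # key-value sequence
Wq = jax.random.normal(k3, (Dq, d))
Wk = jax.random.normal(k4, (Dkv, d))
Wv = jax.random.normal(k5, (Dkv, d))
print(toy_cross_attention(x1, x2, Wq, Wk, Wv).shape)  # (2, 1, 16)
```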
@@ -70,6 +75,27 @@ def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array
 
 @bind(jax.jit, static_argnums=[3, 4, 5, 6])
 def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_softmax=True):
+    """
+    Runs the full nonlinear attentive probe on input encodings (typically embedding vectors produced by some other model).
+
+    Args:
+        params: parameter tuple/list of this probe
+
+        encodings: input encoding vectors/data
+
+        mask: optional mask to be applied to the internal cross-attention
+
+        n_heads: number of attention heads
+
+        dropout: if > 0, applies dropout internally to the cross-attention
+
+        use_LN: use layer normalization?
+
+        use_softmax: should a softmax be applied to the output of the attention probe? (useful for classification)
+
+    Returns:
+        output scores/probabilities, cross-attention (hidden) features
+    """
     # encoded_image_feature: (B, hw, dim)
     #learnable_query, *_params) = params
     learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout, Whid, bhid, Wln_mu, Wln_scale, Wy, by = params
@@ -87,6 +113,30 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
 
 @bind(jax.jit, static_argnums=[4, 5, 6, 7])
 def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_softmax=True):
+    """
+    Runs and evaluates the nonlinear attentive probe given a paired set of encoding vectors and externally
+    assigned labels/regression targets.
+
+    Args:
+        params: parameter tuple/list of this probe
+
+        encodings: input encoding vectors/data
+
+        labels: output target values (e.g., labels, regression target vectors)
+
+        mask: optional mask to be applied to the internal cross-attention
+
+        n_heads: number of attention heads
+
+        dropout: if > 0, applies dropout internally to the cross-attention
+
+        use_LN: use layer normalization?
+
+        use_softmax: should a softmax be applied to the output of the attention probe? (useful for classification)
+
+    Returns:
+        current loss value, output scores/probabilities
+    """
     # encodings: (B, hw, dim)
     outs, _ = run_attention_probe(params, encodings, mask, n_heads, dropout, use_LN, use_softmax)
     if use_softmax: ## Multinoulli log likelihood for 1-of-K predictions
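The comment above notes that, when use_softmax is enabled, the probe is scored with a Multinoulli (categorical) log likelihood over 1-of-K predictions. A standalone sketch of that loss in JAX; the name `multinoulli_nll` is an illustrative assumption, not ngclearn code:

```python
import jax
import jax.numpy as jnp

def multinoulli_nll(probs, one_hot_labels, eps=1e-8):
    # negative log-likelihood of 1-of-K targets under predicted class probabilities
    return -jnp.mean(jnp.sum(one_hot_labels * jnp.log(probs + eps), axis=-1))

logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])   # (B=2, K=3) raw output scores
probs = jax.nn.softmax(logits, axis=-1)                    # probabilities over K classes
labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])     # 1-of-K targets
print(multinoulli_nll(probs, labels))                      # scalar loss value
```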
@@ -97,6 +147,10 @@ def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout:
 
 class AttentiveProbe(Probe):
     """
+    This implements a nonlinear attentive probe, which is useful for evaluating the quality of
+    encodings/embeddings in light of some supervisory downstream data (e.g., label one-hot
+    encodings or real-valued vector regression targets).
+
     Args:
         dkey: init seed key
 
@@ -167,13 +221,34 @@ def __init__(
         self.eta = 0.001
 
     def process(self, embedding_sequence):
+        """
+        Runs the probe's inference scheme given an input batch of sequences of encodings/embeddings.
+
+        Args:
+            embedding_sequence: a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
+
+        Returns:
+            probe output scores/probability values
+        """
         outs, feats = run_attention_probe(
             self.probe_params, embedding_sequence, self.mask, self.num_heads, 0.0, use_LN=self.use_LN,
             use_softmax=self.use_softmax
         )
         return outs
 
     def update(self, embedding_sequence, labels):
+        """
+        Runs and updates this probe given an input batch of sequences of encodings/embeddings and their externally
+        assigned labels/target vector values.
+
+        Args:
+            embedding_sequence: a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
+
+            labels: target values that map to embedding sequence; shape (B, target_value_dim)
+
+        Returns:
+            probe output scores/probability values
+        """
         ## compute partial derivatives / adjustments to probe parameters
         outputs, grads = self.grad_fx(
             self.probe_params, embedding_sequence, labels, self.mask, self.num_heads, dropout=0., use_LN=self.use_LN,

ngclearn/utils/analysis/linear_probe.py

Lines changed: 28 additions & 2 deletions
@@ -51,6 +51,11 @@ def eval_linear_probe(params, x, y, use_softmax=True, use_LN=False):
 
 class LinearProbe(Probe):
     """
+    This implements a regularized linear probe, which is useful for evaluating the quality of
+    encodings/embeddings in light of some supervisory downstream data (e.g., label one-hot
+    encodings or real-valued vector regression targets).
+    Note that this probe allows for configurable Elastic-net (L1+L2) regularization.
+
     Args:
         dkey: init seed key
 
@@ -79,7 +84,6 @@ def __init__(
         self.use_LN = use_LN
         self.l2_decay = 0.0001
         self.l1_decay = 0.000025
-        ## TODO: add in pre-built layer norm of inputs?
 
         ## set up classifier
         flat_input_dim = input_dim * source_seq_length
@@ -97,14 +101,35 @@ def __init__(
         self.eta = 0.001
 
     def process(self, embeddings):
+        """
+        Runs the probe's inference scheme given an input batch of sequences of encodings/embeddings.
+
+        Args:
+            embeddings: a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
+
+        Returns:
+            probe output scores/probability values
+        """
         _embeddings = embeddings
-        if len(_embeddings.shape) > 2:
+        if len(_embeddings.shape) > 2: ## we flatten a sequence batch to 2D for a linear probe
             flat_dim = embeddings.shape[1] * embeddings.shape[2]
             _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim))
         outs = run_linear_probe(self.probe_params, _embeddings, use_softmax=self.use_softmax, use_LN=self.use_LN)
         return outs
 
     def update(self, embeddings, labels):
+        """
+        Runs and updates this probe given an input batch of sequences of encodings/embeddings and their externally
+        assigned labels/target vector values.
+
+        Args:
+            embeddings: a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
+
+            labels: target values that map to embedding sequence; shape (B, target_value_dim)
+
+        Returns:
+            current loss value, probe output scores/probability values
+        """
         _embeddings = embeddings
         if len(_embeddings.shape) > 2:
             flat_dim = embeddings.shape[1] * embeddings.shape[2]
@@ -123,3 +148,4 @@ def update(self, embeddings, labels):
             self.optim_params, self.probe_params, grads, eta=self.eta
         )
         return loss, predictions
+
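The new LinearProbe docstring calls out configurable Elastic-net (L1+L2) regularization, with defaults l2_decay = 0.0001 and l1_decay = 0.000025 set in __init__ above. A minimal sketch of what such a regularized probe objective can look like in JAX; `elastic_net_probe_loss` and its arguments are assumptions for illustration, not ngclearn's internals:

```python
import jax
import jax.numpy as jnp

def elastic_net_probe_loss(params, X, Y, l1_decay=0.000025, l2_decay=0.0001):
    W, b = params
    logits = X @ W + b
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    nll = -jnp.mean(jnp.sum(Y * log_probs, axis=-1))                          # multinoulli log-likelihood term
    reg = l1_decay * jnp.sum(jnp.abs(W)) + l2_decay * jnp.sum(jnp.square(W))  # Elastic-net (L1 + L2) penalty
    return nll + reg

grad_fx = jax.jit(jax.grad(elastic_net_probe_loss))   # gradients w.r.t. (W, b) for a simple SGD/Adam step

key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (4, 8))                    # (batch, flat_input_dim) flattened embeddings
Y = jax.nn.one_hot(jnp.array([0, 2, 1, 2]), 3)        # (batch, num_classes) one-hot labels
W = jnp.zeros((8, 3)); b = jnp.zeros((3,))
grads = grad_fx((W, b), X, Y)                         # tuple of gradients matching (W, b)
```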

ngclearn/utils/analysis/probe.py

Lines changed: 29 additions & 1 deletion
@@ -26,6 +26,15 @@ def update(self, embeddings, labels):
         return L, predictions
 
     def predict(self, data):
+        """
+        Runs this probe's inference scheme over a pool of data.
+
+        Args:
+            data: a dataset or design tensor/matrix containing encoding vector sequences; shape (N, T, embed_dim) or (N, embed_dim)
+
+        Returns:
+            the output scores/predictions made by this probe
+        """
         _data = data
         if len(_data.shape) < 3:
             _data = jnp.expand_dims(_data, axis=1)
@@ -45,13 +54,31 @@ def predict(self, data):
         return Y_mu
 
     def fit(self, data, labels, n_iter=50):
+        """
+        Fits this probe to a pool of data.
+
+        Args:
+            data: a dataset or design tensor/matrix containing encoding vector sequences; shape (N, T, embed_dim) or (N, embed_dim)
+
+            labels: a design matrix containing corresponding labels/targets for the embedding data; shape (N, target_dim)
+
+        Returns:
+            the output scores/predictions made by this probe, along with the (shuffled) target values they are aligned to
+        """
         _data = data
         if len(_data.shape) < 3:
             _data = jnp.expand_dims(_data, axis=1)
 
         n_samples, seq_len, dim = _data.shape
+        size_modulo = n_samples % self.batch_size
+        if size_modulo > 0:
+            ## we append some duplicate data for dataset design tensors that do not divide evenly by the batch size
+            _chunk = _data[0:size_modulo, :, :]
+            _data = jnp.concatenate((_data, _chunk), axis=0)
+            n_samples, seq_len, dim = _data.shape
         n_batches = int(n_samples / self.batch_size)
 
+        ## run main probe fitting loop
         Y_mu = []
         _Y = None
         for iter in range(n_iter):
@@ -81,4 +108,5 @@ def fit(self, data, labels, n_iter=50):
             print()
             if iter == n_iter - 1:
                 Y_mu = jnp.concatenate(Y_mu, axis=0)
-        return Y_mu, _Y
+        return Y_mu, _Y ## return predictions mapped to current shuffling of labels
+
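The size_modulo logic added to fit() duplicates a few leading samples so the design tensor can be split into whole mini-batches. A standalone sketch of that padding idea with made-up shapes; note this sketch pads by batch_size minus the remainder, so treat it as an illustration of the goal rather than a copy of the committed code:

```python
import jax.numpy as jnp

batch_size = 4
data = jnp.ones((10, 3, 8))                       # (N=10, T=3, embed_dim=8) design tensor
remainder = data.shape[0] % batch_size            # 10 % 4 = 2 samples left over
if remainder > 0:
    pad = data[0:batch_size - remainder]          # duplicate leading rows to complete the last batch
    data = jnp.concatenate((data, pad), axis=0)   # now (12, 3, 8)
n_batches = data.shape[0] // batch_size           # 3 full mini-batches, no samples dropped
print(data.shape, n_batches)
```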
