
Commit 84005b5

Author: Alexander Ororbia
cleaned up probes
1 parent 27fd9bf

3 files changed (+64, -65 lines)


ngclearn/utils/analysis/attentive_probe.py

Lines changed: 35 additions & 40 deletions
@@ -21,8 +21,8 @@ def masked_fill(x: jax.Array, mask: jax.Array, value=0) -> jax.Array:
     """
     return jnp.where(mask, jnp.broadcast_to(value, x.shape), x)
 
-@bind(jax.jit, static_argnums=[4, 5])
-def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array, n_heads: int=8, dropout_rate: float=0.0) -> jax.Array:
+@bind(jax.jit, static_argnums=[5, 6])
+def cross_attention(dkey, params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array, n_heads: int=8, dropout_rate: float=0.0) -> jax.Array:
     """
     Run cross-attention function given a list of parameters and two sequences (x1 and x2).
     The function takes in a query sequence x1 and a key-value sequence x2, and returns an output of the same shape as x1.
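Prepending `dkey` as a new first positional argument shifts every later positional index by one, which is why `static_argnums` moves from `[4, 5]` to `[5, 6]`: the static arguments are still `n_heads` and `dropout_rate`, they just sit one slot further right. A minimal sketch of the same index shift, assuming `bind` behaves like `functools.partial` applied to a decorator (`toy_fn` is hypothetical):

```python
from functools import partial
import jax
import jax.numpy as jnp

bind = partial  # assumption: ngclearn's `bind` partially applies a decorator

@bind(jax.jit, static_argnums=[1])  # index 1 = `scale`, a Python int kept static
def toy_fn(x: jax.Array, scale: int) -> jax.Array:
    return x * scale

@bind(jax.jit, static_argnums=[2])  # `scale` slides from index 1 to index 2
def toy_fn_with_key(dkey, x: jax.Array, scale: int) -> jax.Array:
    del dkey  # unused; present only to mirror the new key-first signature
    return x * scale
```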
@@ -31,6 +31,8 @@ def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array
     H is the number of attention heads.
 
     Args:
+        dkey: JAX key to trigger any internal noise (drop-out)
+
         params (tuple): tuple of parameters
 
         x1 (jax.Array): query sequence. Shape: (B, T, Dq)
@@ -68,17 +70,22 @@ def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array
     score = jax.nn.softmax(score, axis=-1) # (B, H, T, S)
     score = score.astype(q.dtype) # (B, H, T, S)
     if dropout_rate > 0.:
-        score = drop_out(input=score, rate=dropout_rate) ## NOTE: normally you apply dropout here
+        score = drop_out(dkey, input=score, rate=dropout_rate) ## NOTE: normally you apply dropout here
     attention = jnp.einsum("BHTS,BHSE->BHTE", score, v) # (B, T, H, E)
     attention = attention.transpose([0, 2, 1, 3]).reshape((B, T, -1)) # (B, T, H, E) => (B, T, D)
     return attention @ Wout + bout # (B, T, Dq)
 
-@bind(jax.jit, static_argnums=[3, 4, 5, 6, 7])
-def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False, use_softmax=True):
+@bind(jax.jit, static_argnums=[4, 5, 6, 7, 8])
+def run_attention_probe(
+        dkey, params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False,
+        use_softmax=True
+):
     """
     Runs full nonlinear attentive probe on input encodings (typically embedding vectors produced by some other model).
 
     Args:
+        dkey: JAX key for any internal noise to be applied
+
         params: parameters tuple/list of probe
 
         encodings: input encoding vectors/data
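The dropout call now receives the key explicitly instead of relying on hidden state, which keeps the function pure under `jax.jit`. For orientation, a drop-out routine matching the call signature used above could look like the following sketch (ngclearn's actual `drop_out` is defined elsewhere and may differ):

```python
import jax
import jax.numpy as jnp

def drop_out(dkey, input: jax.Array, rate: float = 0.0) -> jax.Array:
    """Inverted-dropout sketch: zero entries with probability `rate` and
    rescale survivors by 1/(1 - rate) to preserve the expected value."""
    keep = jax.random.bernoulli(dkey, p=1.0 - rate, shape=input.shape)
    return jnp.where(keep, input / (1.0 - rate), 0.0)
```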
@@ -91,6 +98,8 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
 
         use_LN: use layer normalization?
 
+        use_LN_input: use layer normalization on input encodings?
+
         use_softmax: should softmax be applied to output of attention probe? (useful for classification)
 
     Returns:
@@ -107,7 +116,7 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
     if use_LN_input:
         learnable_query = layer_normalize(learnable_query, ln_in_mu, ln_in_scale)
         encodings = layer_normalize(encodings, ln_in_mu2, ln_in_scale2)
-    features = cross_attention(cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
+    features = cross_attention(dkey, cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
     # Perform a single self-attention block here
     # Self-Attention
     self_attn_params = (Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts)
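`run_attention_probe` forwards its single `dkey` straight into `cross_attention`. If further noise sites (say, dropout inside the self-attention block that follows) were keyed as well, the usual JAX pattern would be to split one incoming key so each site draws independent randomness; illustratively:

```python
from jax import random

key = random.PRNGKey(42)
# One subkey per prospective noise site (names are illustrative only).
key, cross_attn_key, self_attn_key = random.split(key, 3)
```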
@@ -138,13 +147,15 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
         outs = jax.nn.softmax(outs)
     return outs, features
 
-@bind(jax.jit, static_argnums=[4, 5, 6, 7, 8])
-def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False, use_softmax=True):
+@bind(jax.jit, static_argnums=[5, 6, 7, 8, 9])
+def eval_attention_probe(dkey, params, encodings, labels, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False, use_softmax=True):
     """
     Runs and evaluates the nonlinear attentive probe given a paired set of encoding vectors and externally assigned
     labels/regression targets.
 
     Args:
+        dkey: JAX key to trigger any internal noise (as in drop-out)
+
         params: parameters tuple/list of probe
 
         encodings: input encoding vectors/data
@@ -165,7 +176,7 @@ def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout:
         current loss value, output scores/probabilities
     """
     # encodings: (B, hw, dim)
-    outs, _ = run_attention_probe(params, encodings, mask, n_heads, dropout, use_LN, use_LN_input, use_softmax)
+    outs, _ = run_attention_probe(dkey, params, encodings, mask, n_heads, dropout, use_LN, use_LN_input, use_softmax)
     if use_softmax: ## Multinoulli log likelihood for 1-of-K predictions
         L = -jnp.mean(jnp.sum(jnp.log(outs.clip(min=1e-5)) * labels, axis=1, keepdims=True))
     else: ## MSE for real-valued outputs
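With one-hot `labels`, the clipped multinoulli log-likelihood above is ordinary categorical cross-entropy; the `clip(min=1e-5)` guards the logarithm against zero probabilities. A standalone check with made-up values:

```python
import jax.numpy as jnp

outs = jnp.array([[0.7, 0.2, 0.1]])    # probe output probabilities, (B=1, K=3)
labels = jnp.array([[1.0, 0.0, 0.0]])  # one-hot target
L = -jnp.mean(jnp.sum(jnp.log(outs.clip(min=1e-5)) * labels, axis=1, keepdims=True))
# L == -log(0.7) ≈ 0.3567
```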
@@ -219,6 +230,7 @@ def __init__(
         self.use_softmax = use_softmax
         self.use_LN = use_LN
         self.use_LN_input = use_LN_input
+        self.dropout = 0.5
 
         sigma = 0.05
         ## cross-attention parameters
@@ -275,42 +287,25 @@ def __init__(
         self.optim_params = adam.adam_init(self.probe_params)
         self.eta = 0.0002 #0.001
 
-    def process(self, embedding_sequence):
-        """
-        Runs the probe's inference scheme given an input batch of sequences of encodings/embeddings.
-
-        Args:
-            embedding_sequence: a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
-
-        Returns:
-            probe output scores/probability values
-        """
-        #print(embedding_sequence.shape)
+    def process(self, embeddings, dkey=None):
+        noise_key = None
+        if dkey is not None:
+            dkey, *subkeys = random.split(dkey, 2)
+            noise_key = subkeys[0]
         outs, feats = run_attention_probe(
-            self.probe_params, embedding_sequence, self.dev_mask, self.num_heads, 0.0, use_LN=self.use_LN,
-            use_LN_input=self.use_LN_input, use_softmax=self.use_softmax
+            noise_key, self.probe_params, embeddings, self.dev_mask, self.num_heads, 0.0,
+            use_LN=self.use_LN, use_LN_input=self.use_LN_input, use_softmax=self.use_softmax
         )
         return outs
 
-    def update(self, embedding_sequence, labels, dkey=None):
-        """
-        Runs and updates this probe given an input batch of sequences of encodings/embeddings and their externally
-        assigned labels/target vector values.
-
-        Args:
-            embedding_sequence: a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
-
-            labels: target values that map to embedding sequence; shape (B, target_value_dim)
-
-        Returns:
-            probe output scores/probability values
-        """
-        # TODO: put in dkey to facilitate dropout
-        ## compute partial derivatives / adjustments to probe parameters
-        # NOTE: Viet: Change back to 0.0 for now for the code to run
+    def update(self, embeddings, labels, dkey=None):
+        noise_key = None
+        if dkey is not None:
+            dkey, *subkeys = random.split(dkey, 2)
+            noise_key = subkeys[0]
         outputs, grads = self.grad_fx(
-            self.probe_params, embedding_sequence, labels, self.mask, self.num_heads, dropout=0.0, use_LN=self.use_LN,
-            use_LN_input=self.use_LN_input, use_softmax=self.use_softmax
+            noise_key, self.probe_params, embeddings, labels, self.mask, self.num_heads, dropout=self.dropout,
+            use_LN=self.use_LN, use_LN_input=self.use_LN_input, use_softmax=self.use_softmax
         )
         loss, predictions = outputs
         ## adjust parameters of probe
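With these changes, `update` draws a fresh subkey per call and trains with the newly hard-coded rate `self.dropout = 0.5`, while `process` keeps passing `0.0` so inference stays deterministic; the removed TODO and the "change back to 0.0" workaround are thereby resolved. A hypothetical training loop against the new API (`probe`, `embeddings`, and `labels` are stand-ins):

```python
from jax import random
import jax.numpy as jnp

key = random.PRNGKey(0)
embeddings = jnp.ones((8, 16, 64))  # (B, T, embed_dim), made-up shapes
labels = jnp.eye(10)[:8]            # (B, num_classes) one-hot targets

for step in range(100):
    key, subkey = random.split(key)
    probe.update(embeddings, labels, dkey=subkey)  # dropout active (rate 0.5)

scores = probe.process(embeddings)  # no key passed: noise_key=None, deterministic
```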

ngclearn/utils/analysis/linear_probe.py

Lines changed: 2 additions & 23 deletions
@@ -101,36 +101,15 @@ def __init__(
         self.optim_params = adam.adam_init(self.probe_params)
         self.eta = 0.001
 
-    def process(self, embeddings):
-        """
-        Runs the probe's inference scheme given an input batch of sequences of encodings/embeddings.
-
-        Args:
-            embedding_sequence: a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
-
-        Returns:
-            probe output scores/probability values
-        """
+    def process(self, embeddings, dkey=None):
         _embeddings = embeddings
         if len(_embeddings.shape) > 2: ## we flatten a sequence batch to 2D for a linear probe
             flat_dim = embeddings.shape[1] * embeddings.shape[2]
             _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim))
         outs = run_linear_probe(self.probe_params, _embeddings, use_softmax=self.use_softmax, use_LN=self.use_LN)
         return outs
 
-    def update(self, embeddings, labels):
-        """
-        Runs and updates this probe given an input batch of sequences of encodings/embeddings and their externally
-        assigned labels/target vector values.
-
-        Args:
-            embedding_sequence: a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
-
-            labels: target values that map to embedding sequence; shape (B, target_value_dim)
-
-        Returns:
-            probe output scores/probability values
-        """
+    def update(self, embeddings, labels, dkey=None):
         _embeddings = embeddings
         if len(_embeddings.shape) > 2:
             flat_dim = embeddings.shape[1] * embeddings.shape[2]
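The linear probe now accepts `dkey` purely for interface parity with the attentive probe; it is deterministic and never consumes the key. Its flattening step in isolation, with assumed shapes:

```python
import jax.numpy as jnp

embeddings = jnp.ones((4, 16, 32))  # (B, T, D) sequence batch
flat_dim = embeddings.shape[1] * embeddings.shape[2]
flat = jnp.reshape(embeddings, (embeddings.shape[0], flat_dim))
assert flat.shape == (4, 512)       # 2D input for the linear map
```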

ngclearn/utils/analysis/probe.py

Lines changed: 27 additions & 2 deletions
@@ -18,11 +18,36 @@ def __init__(
         self.batch_size = batch_size
         self.dev_batch_size = dev_batch_size
 
-    def process(self, embeddings):
+    def process(self, embeddings, dkey=None):
+        """
+        Runs the probe's inference scheme given an input batch of sequences of encodings/embeddings.
+
+        Args:
+            embeddings: a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
+
+            dkey: Optional JAX noise key
+
+        Returns:
+            probe output scores/probability values
+        """
         predictions = None
         return predictions
 
-    def update(self, embeddings, labels):
+    def update(self, embeddings, labels, dkey=None):
+        """
+        Runs and updates this probe given an input batch of sequences of encodings/embeddings and their externally
+        assigned labels/target vector values.
+
+        Args:
+            embeddings: a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
+
+            labels: target values that map to embedding sequence; shape (B, target_value_dim)
+
+            dkey: Optional JAX noise key
+
+        Returns:
+            probe output scores/probability values
+        """
         L = predictions = None
         return L, predictions
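The base class now documents the contract both concrete probes follow: `process` and `update` accept an optional `dkey`, and deterministic probes simply ignore it. A minimal hypothetical subclass honoring that contract (assuming the base class is named `Probe`, as the filename suggests):

```python
import jax.numpy as jnp

class MeanPoolProbe(Probe):
    """Illustrative only; not part of this commit."""

    def process(self, embeddings, dkey=None):
        return jnp.mean(embeddings, axis=1)  # (B, T, D) -> (B, D); key unused

    def update(self, embeddings, labels, dkey=None):
        predictions = self.process(embeddings, dkey)
        L = jnp.mean((predictions - labels) ** 2)  # simple MSE; no learning step
        return L, predictions
```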
