Commit 796178d

Author: Alexander Ororbia (committed)
commit probes/mods to utils to analysis_tools branch
1 parent 35eae76 commit 796178d

File tree

5 files changed: +423 -1 lines changed
ngclearn/utils/analysis/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@

## point to supported analysis probes
from .linear_probe import LinearProbe
from .attentive_probe import AttentiveProbe
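
With these re-exports in place, downstream analysis scripts can import both probes from the package itself. A minimal, hypothetical caller line (assuming the new files live alongside ngclearn/utils/analysis/probe.py shown later in this commit):

## hypothetical downstream import enabled by the re-exports above
from ngclearn.utils.analysis import LinearProbe, AttentiveProbe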

ngclearn/utils/analysis/attentive_probe.py

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
import jax
import numpy as np
from ngclearn.utils.analysis.probe import Probe
from ngclearn.utils.model_utils import drop_out, softmax, gelu, layer_normalize
from ngclearn.utils.optim import adam
from jax import jit, random, numpy as jnp, lax, nn
from functools import partial as bind

def masked_fill(x: jax.Array, mask: jax.Array, value=0) -> jax.Array:
    """
    Return a copy of `x` in which every position where `mask` is True is
    replaced with `value`; non-masked positions keep their original values.

    Args:
        x (jax.Array): input array to (partially) overwrite

        mask (jax.Array): boolean mask broadcastable to the shape of `x`

        value (int, optional): fill value placed at masked positions. Defaults to 0.

    Returns:
        jax.Array: array of the same shape as `x` with masked positions set to `value`
    """
    return jnp.where(mask, jnp.broadcast_to(value, x.shape), x)

@bind(jax.jit, static_argnums=[4, 5])
def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array, n_heads: int=8, dropout_rate: float=0.0):
    B, T, Dq = x1.shape  # target/query-side shape
    _, S, Dkv = x2.shape  # source/key-value-side shape
    ## here, x1 (queries) attends over x2 (keys/values)
    Wq, bq, Wk, bk, Wv, bv, Wout, bout = params
    ## projections
    q = x1 @ Wq + bq  # linear projection (B, T, D)
    k = x2 @ Wk + bk  # linear projection (B, S, D)
    v = x2 @ Wv + bv  # linear projection (B, S, D)
    hidden = q.shape[-1]
    _hidden = hidden // n_heads
    q = q.reshape((B, T, n_heads, _hidden)).transpose([0, 2, 1, 3])  # (B, H, T, E)
    k = k.reshape((B, S, n_heads, _hidden)).transpose([0, 2, 1, 3])  # (B, H, S, E)
    v = v.reshape((B, S, n_heads, _hidden)).transpose([0, 2, 1, 3])  # (B, H, S, E)
    score = jnp.einsum("BHTE,BHSE->BHTS", q, k) / jnp.sqrt(_hidden)  # Q @ K^T / sqrt(E); E = D // n_heads
    if mask is not None:
        Tq, Tk = q.shape[2], k.shape[2]
        assert mask.shape == (B, Tq, Tk), (mask.shape, (B, Tq, Tk))
        _mask = mask.reshape((B, 1, Tq, Tk))  # 'b tq tk -> b 1 tq tk'
        score = masked_fill(score, _mask, value=-jnp.inf)  ## mask out all positions that must not be attended to
    score = jax.nn.softmax(score, axis=-1)  # (B, H, T, S)
    score = score.astype(q.dtype)  # (B, H, T, S)
    if dropout_rate > 0.:
        score = drop_out(input=score, rate=dropout_rate)  ## NOTE: dropout is normally applied to the attention weights here
    attention = jnp.einsum("BHTS,BHSE->BHTE", score, v)  # (B, H, T, E)
    attention = attention.transpose([0, 2, 1, 3]).reshape((B, T, -1))  # (B, H, T, E) => (B, T, D)
    return attention @ Wout + bout  # (B, T, Dq)

@bind(jax.jit, static_argnums=[3, 4, 5, 6])
def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_softmax=True):
    # encoded_image_feature: (B, hw, dim)
    learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout, Whid, bhid, Wln_mu, Wln_scale, Wy, by = params
    attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
    features = cross_attention(attn_params, learnable_query, encodings, mask, n_heads, dropout)
    features = features[:, 0]  # (B, 1, dim) => (B, dim)
    hids = jnp.matmul((features + learnable_query[:, 0]), Whid) + bhid
    hids = gelu(hids)
    if use_LN:  ## normalize hidden layer output of probe predictor
        hids = layer_normalize(hids, Wln_mu, Wln_scale)
    outs = jnp.matmul(hids, Wy) + by
    if use_softmax:  ## apply softmax output nonlinearity
        outs = softmax(outs)
    return outs, features

@bind(jax.jit, static_argnums=[4, 5, 6, 7])
def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_softmax=True):
    # encodings: (B, hw, dim)
    outs, _ = run_attention_probe(params, encodings, mask, n_heads, dropout, use_LN, use_softmax)
    if use_softmax:  ## Multinoulli log likelihood for 1-of-K predictions
        L = -jnp.mean(jnp.sum(jnp.log(outs) * labels, axis=1, keepdims=True))
    else:  ## MSE for real-valued outputs
        L = jnp.mean(jnp.sum(jnp.square(outs - labels), axis=1, keepdims=True))
    return L, outs

class AttentiveProbe(Probe):
    """
    A learnable cross-attention probe: a trainable query attends over a source
    sequence of embeddings to pool it into a single feature vector, which is then
    mapped to a prediction by a small (optionally layer-normalized) MLP.

    Args:
        dkey: init seed key

        source_seq_length: length of input sequence (e.g., height x width of the image feature)

        input_dim: input dimensionality of probe

        out_dim: output dimensionality of probe

        num_heads: number of cross-attention heads

        head_dim: output dimensionality of each cross-attention head

        target_seq_length: length of the target (query) sequence; to pool, this is set to
            one (i.e., the source sequence is mapped to a target sequence of length 1)

        learnable_query_dim: target sequence dim (output dimension of cross-attention portion of probe)

        batch_size: size of batches to process per internal call to update (or process)

        hid_dim: dimensionality of hidden layer(s) of MLP portion of probe

        use_LN: should layer normalization be used within MLP portions of probe or not?

        use_softmax: should a softmax be applied to output of probe or not?
    """
    def __init__(
            self, dkey, source_seq_length, input_dim, out_dim, num_heads=8, head_dim=64,
            target_seq_length=1, learnable_query_dim=31, batch_size=1, hid_dim=32, use_LN=True, use_softmax=True, **kwargs
    ):
        super().__init__(dkey, batch_size, **kwargs)
        self.dkey, *subkeys = random.split(self.dkey, 13)  ## one fresh subkey per randomly initialized tensor
        self.num_heads = num_heads
        self.source_seq_length = source_seq_length
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.use_softmax = use_softmax
        self.use_LN = use_LN

        sigma = 0.05
        ## cross-attention parameters
        Wq = random.normal(subkeys[0], (learnable_query_dim, head_dim)) * sigma
        bq = random.normal(subkeys[1], (1, head_dim)) * sigma
        Wk = random.normal(subkeys[2], (input_dim, head_dim)) * sigma
        bk = random.normal(subkeys[3], (1, head_dim)) * sigma
        Wv = random.normal(subkeys[4], (input_dim, head_dim)) * sigma
        bv = random.normal(subkeys[5], (1, head_dim)) * sigma
        Wout = random.normal(subkeys[6], (head_dim, learnable_query_dim)) * sigma
        bout = random.normal(subkeys[7], (1, learnable_query_dim)) * sigma
        learnable_query = jnp.zeros((batch_size, 1, learnable_query_dim))  # (B, T, D)
        self.mask = np.zeros((batch_size, target_seq_length, source_seq_length)).astype(bool)  ## mask tensor
        ## MLP parameters
        Whid = random.normal(subkeys[8], (learnable_query_dim, hid_dim)) * sigma
        bhid = random.normal(subkeys[9], (1, hid_dim)) * sigma
        Wln_mu = jnp.zeros((1, hid_dim))
        Wln_scale = jnp.ones((1, hid_dim))
        Wy = random.normal(subkeys[10], (hid_dim, out_dim)) * sigma
        by = random.normal(subkeys[11], (1, out_dim)) * sigma
        self.probe_params = (learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout, Whid, bhid, Wln_mu, Wln_scale, Wy, by)

        ## set up gradient calculator
        self.grad_fx = jax.value_and_grad(eval_attention_probe, argnums=0, has_aux=True)
        ## set up update rule/optimizer
        self.optim_params = adam.adam_init(self.probe_params)
        self.eta = 0.001

    def process(self, embedding_sequence):
        outs, feats = run_attention_probe(
            self.probe_params, embedding_sequence, self.mask, self.num_heads, 0.0, use_LN=self.use_LN,
            use_softmax=self.use_softmax
        )
        return outs

    def update(self, embedding_sequence, labels):
        ## compute partial derivatives / adjustments to probe parameters
        outputs, grads = self.grad_fx(
            self.probe_params, embedding_sequence, labels, self.mask, self.num_heads, dropout=0., use_LN=self.use_LN,
            use_softmax=self.use_softmax
        )
        loss, predictions = outputs
        ## adjust parameters of probe
        self.optim_params, self.probe_params = adam.adam_step(
            self.optim_params, self.probe_params, grads, eta=self.eta
        )
        return loss, predictions
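
The class above is self-contained enough to sketch how it might be driven. The following is a minimal, hypothetical example (feature shapes, class count, and PRNG seeds are invented for illustration; only the constructor, update, and process signatures come from the diff itself):

from jax import random, numpy as jnp
from ngclearn.utils.analysis.attentive_probe import AttentiveProbe

dkey = random.PRNGKey(42)
## hypothetical frozen-encoder output: 8 sequences of 49 tokens (e.g., a 7x7 feature map), 256-dim each
feats = random.normal(random.PRNGKey(0), (8, 49, 256))   # (B, hw, dim)
labels = jnp.eye(10)[jnp.arange(8) % 10]                  # one-hot targets over 10 classes

probe = AttentiveProbe(
    dkey, source_seq_length=49, input_dim=256, out_dim=10, num_heads=8,
    head_dim=64, batch_size=8, hid_dim=32, use_LN=True, use_softmax=True
)
loss, preds = probe.update(feats, labels)   ## one Adam step on the probe parameters
outs = probe.process(feats)                 ## (8, 10) class probabilities

Note that, because the learnable query and the attention mask are allocated with a leading batch_size dimension, the batch handed to update or process must match the batch_size given to the constructor.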

ngclearn/utils/analysis/linear_probe.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
import jax
import numpy as np
from ngclearn.utils.analysis.probe import Probe
from ngclearn.utils.model_utils import drop_out, softmax, layer_normalize
from jax import jit, random, numpy as jnp, lax, nn
from functools import partial as bind
import ngclearn.utils.weight_distribution as dist
from ngclearn.utils.optim import adam, sgd

@bind(jax.jit, static_argnums=[2, 3])
def run_linear_probe(params, x, use_softmax=False, use_LN=False):
    Wln_mu, Wln_scale, W, b = params
    _x = x
    if use_LN:  ## normalize input vector to probe predictor
        _x = layer_normalize(_x, Wln_mu, Wln_scale)
    y_mu = (jnp.matmul(_x, W) + b)
    if use_softmax:
        y_mu = softmax(y_mu)
    return y_mu

@bind(jax.jit, static_argnums=[3, 4])
def eval_linear_probe(params, x, y, use_softmax=True, use_LN=False):
    y_mu = run_linear_probe(params, x, use_softmax=use_softmax, use_LN=use_LN)
    e = y_mu - y
    if use_softmax:  ## Multinoulli log likelihood for 1-of-K predictions
        L = -jnp.mean(jnp.sum(jnp.log(y_mu) * y, axis=1, keepdims=True))
    else:  ## MSE for real-valued outputs
        L = jnp.sum(jnp.square(e)) * 1./x.shape[0]
    return L, y_mu

# @bind(jax.jit, static_argnums=[6, 7])
# def calc_linear_probe_grad(x, y, params, eta, decay=0., l1_decay=0., use_softmax=False, use_LN=False):
#     y_mu, L, e = eval_linear_probe(params, x, y, use_softmax=use_softmax, use_LN=use_LN)
#     Wln_mu, Wln_scale, W, b = params
#     dW = jnp.matmul(x.T, e) + W * decay/eta + jnp.abs(W) * 0.5 * l1_decay/eta
#     db = jnp.sum(e, axis=0, keepdims=True)
#     dW = dW * (1. / x.shape[0])
#     db = db * (1. / x.shape[0])
#     return y_mu, L, [dW, db]

# @jit
# def update_linear_probe(x, y, params, eta, decay=0., l1_decay=0., use_softmax=False):
#     y_mu, L, e = run_linear_probe(x, params, use_softmax=use_softmax)
#     W, b = params
#     dW = jnp.matmul(x.T, e)
#     db = jnp.sum(e, axis=0, keepdims=True)
#     W = W - dW * eta/x.shape[0] - W * decay/x.shape[0] - jnp.abs(W) * 0.5 * l1_decay/x.shape[0]
#     b = b - db * eta/x.shape[0]
#     return y_mu, L, [W, b]

class LinearProbe(Probe):
    """
    A simple linear read-out probe: input embeddings are flattened, optionally
    layer-normalized, and mapped through a single affine transform, with an
    optional softmax applied to the output.

    Args:
        dkey: init seed key

        source_seq_length: length of input sequence (e.g., height x width of the image feature)

        input_dim: input dimensionality of probe

        out_dim: output dimensionality of probe

        batch_size: size of batches to process per internal call to update (or process)

        use_LN: should layer normalization be used on incoming input vectors given to this probe?

        use_softmax: should a softmax be applied to output of probe or not?
    """
    def __init__(
            self, dkey, source_seq_length, input_dim, out_dim, batch_size=1, use_LN=False, use_softmax=False, **kwargs
    ):
        super().__init__(dkey, batch_size, **kwargs)
        self.dkey, *subkeys = random.split(self.dkey, 3)
        self.source_seq_length = source_seq_length
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.use_softmax = use_softmax
        self.use_LN = use_LN
        self.l2_decay = 0.0001
        self.l1_decay = 0.000025
        ## TODO: add in pre-built layer norm of inputs?

        ## set up classifier
        flat_input_dim = input_dim * source_seq_length
        weight_init = dist.fan_in_gaussian()  # alternative: dist.gaussian(mu=0., sigma=0.05)
        Wln_mu = jnp.zeros((1, flat_input_dim))
        Wln_scale = jnp.ones((1, flat_input_dim))
        W = dist.initialize_params(subkeys[0], weight_init, (flat_input_dim, out_dim))
        b = jnp.zeros((1, out_dim))
        self.probe_params = [Wln_mu, Wln_scale, W, b]

        ## set up gradient calculator
        self.grad_fx = jax.value_and_grad(eval_linear_probe, argnums=0, has_aux=True)
        ## set up update rule/optimizer
        self.optim_params = adam.adam_init(self.probe_params)
        self.eta = 0.001

    def process(self, embeddings):
        _embeddings = embeddings
        if len(_embeddings.shape) > 2:
            flat_dim = embeddings.shape[1] * embeddings.shape[2]
            _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim))
        outs = run_linear_probe(self.probe_params, _embeddings, use_softmax=self.use_softmax, use_LN=self.use_LN)
        return outs

    def update(self, embeddings, labels):
        _embeddings = embeddings
        if len(_embeddings.shape) > 2:
            flat_dim = embeddings.shape[1] * embeddings.shape[2]
            _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim))
        ## compute adjustments to probe parameters
        # predictions, loss, grads = calc_linear_probe_grad(
        #     self.probe_params, _embeddings, labels, self.eta, decay=self.l2_decay, l1_decay=self.l1_decay,
        #     use_softmax=self.use_softmax, use_LN=self.use_LN
        # )
        outputs, grads = self.grad_fx(
            self.probe_params, _embeddings, labels, use_softmax=self.use_softmax, use_LN=self.use_LN
        )
        loss, predictions = outputs
        ## adjust parameters of probe
        self.optim_params, self.probe_params = adam.adam_step(
            self.optim_params, self.probe_params, grads, eta=self.eta
        )
        return loss, predictions
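
For contrast with the attentive probe, here is a similarly hedged sketch of driving the linear probe one step at a time (shapes and labels are fabricated; the flattening of (B, seq, dim) inputs is handled inside process/update as shown above):

from jax import random, numpy as jnp
from ngclearn.utils.analysis.linear_probe import LinearProbe

dkey = random.PRNGKey(1234)
## hypothetical embeddings: 32 samples, 16 tokens of dimension 64 (flattened internally to 16 * 64 = 1024)
feats = random.normal(random.PRNGKey(0), (32, 16, 64))
labels = jnp.eye(10)[jnp.arange(32) % 10]   # one-hot targets over 10 classes

probe = LinearProbe(dkey, source_seq_length=16, input_dim=64, out_dim=10,
                    batch_size=32, use_LN=True, use_softmax=True)
loss, preds = probe.update(feats, labels)   ## one Adam step over this mini-batch
outs = probe.process(feats)                 ## (32, 10) class probabilities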

ngclearn/utils/analysis/probe.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
from jax import random, numpy as jnp

class Probe():
    """
    General framework for an analysis probe (that may or may not be learnable in an iterative fashion).

    Args:
        dkey: init seed key

        batch_size: size of batches to process per internal call to update (or process)
    """
    def __init__(
            self, dkey, batch_size=4, **kwargs
    ):
        #dkey, *subkeys = random.split(dkey, 3)
        self.dkey = dkey
        self.batch_size = batch_size

    def process(self, embeddings):  ## to be overridden by concrete probes
        predictions = None
        return predictions

    def update(self, embeddings, labels):  ## to be overridden by concrete probes
        L = predictions = None
        return L, predictions

    def predict(self, data):
        _data = data
        if len(_data.shape) < 3:
            _data = jnp.expand_dims(_data, axis=1)

        n_samples, seq_len, dim = _data.shape
        n_batches = int(n_samples / self.batch_size)  ## note: remainder samples beyond the last full batch are skipped
        s_ptr = 0
        e_ptr = self.batch_size
        Y_mu = []
        for b in range(n_batches):
            x_mb = _data[s_ptr:e_ptr, :, :]  ## slice out 3D batch tensor
            s_ptr = e_ptr
            e_ptr += x_mb.shape[0]
            y_mu = self.process(x_mb)
            Y_mu.append(y_mu)
        Y_mu = jnp.concatenate(Y_mu, axis=0)
        return Y_mu

    def fit(self, data, labels, n_iter=50):
        _data = data
        if len(_data.shape) < 3:
            _data = jnp.expand_dims(_data, axis=1)

        n_samples, seq_len, dim = _data.shape
        n_batches = int(n_samples / self.batch_size)

        Y_mu = []
        _Y = None
        for iter in range(n_iter):
            ## shuffle data (to ensure i.i.d. sampling across sequences)
            self.dkey, *subkeys = random.split(self.dkey, 2)
            ptrs = random.permutation(subkeys[0], n_samples)
            _X = _data[ptrs, :, :]
            _Y = labels[ptrs, :]
            ## run one epoch over data tensors
            L = 0.
            Ns = 0.

            s_ptr = 0
            e_ptr = self.batch_size
            for b in range(n_batches):
                x_mb = _X[s_ptr:e_ptr, :, :]  ## slice out 3D batch tensor
                y_mb = _Y[s_ptr:e_ptr, :]
                s_ptr = e_ptr
                e_ptr += x_mb.shape[0]
                Ns += x_mb.shape[0]

                _L, py = self.update(x_mb, y_mb)
                L = _L + L
                print(f"\r{iter} L = {L/Ns}", end="")
                if iter == n_iter-1:  ## collect predictions made during the final epoch
                    Y_mu.append(py)
            print()
            if iter == n_iter - 1:
                Y_mu = jnp.concatenate(Y_mu, axis=0)
        return Y_mu, _Y
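
Since both concrete probes inherit fit and predict from this base class, a full training-and-readout loop only needs the dataset tensors. A hedged end-to-end sketch (dataset sizes and the accuracy readout are illustrative, not part of the commit):

from jax import random, numpy as jnp
from ngclearn.utils.analysis import LinearProbe

dkey = random.PRNGKey(7)
X = random.normal(random.PRNGKey(0), (256, 16, 64))   # (n_samples, seq_len, dim)
Y = jnp.eye(10)[jnp.arange(256) % 10]                  # one-hot labels

probe = LinearProbe(dkey, source_seq_length=16, input_dim=64, out_dim=10,
                    batch_size=32, use_softmax=True)
Y_mu_final, Y_shuffled = probe.fit(X, Y, n_iter=20)  ## final-epoch predictions + labels (shuffled order)
Y_hat = probe.predict(X)                             ## batched forward passes in the original data order
acc = jnp.mean(jnp.argmax(Y_hat, axis=1) == jnp.argmax(Y, axis=1))

Note that fit returns the final-epoch predictions aligned to its internally shuffled labels (the second return value), while predict preserves the caller's ordering.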
