Commit 8f75b0d

Merge pull request #93 from NACLab/analysis_tools
Merge of analysis tools feature branch to main
2 parents ffd8f0e + 08b4d12 commit 8f75b0d

6 files changed: +789 −21 lines changed


ngclearn/components/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 from .neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell
 from .neurons.spiking.izhikevichCell import IzhikevichCell
 from .neurons.spiking.RAFCell import RAFCell
+
 ## point to transformer/operater component types
 from .other.varTrace import VarTrace
 from .other.expKernel import ExpKernel

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
## point to supported analysis probes
from .linear_probe import LinearProbe
from .attentive_probe import AttentiveProbe
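
With these exports in place, downstream code can presumably import the new probes directly from the analysis package (assuming this is the analysis package's __init__.py; the file path is not shown in this view):

from ngclearn.utils.analysis import LinearProbe, AttentiveProbe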

Lines changed: 330 additions & 0 deletions
@@ -0,0 +1,330 @@
import jax
import numpy as np
from ngclearn.utils.analysis.probe import Probe
from ngclearn.utils.model_utils import drop_out, softmax, gelu, layer_normalize
from ngclearn.utils.optim import adam
from jax import jit, random, numpy as jnp, lax, nn
from functools import partial as bind

def masked_fill(x: jax.Array, mask: jax.Array, value=0) -> jax.Array:
    """
    Fill the positions of `x` where `mask` is True with `value`, leaving all
    other positions unchanged.

    Args:
        x (jax.Array): input array

        mask (jax.Array): boolean mask, broadcastable to the shape of `x`

        value (int, optional): fill value for masked positions. Defaults to 0.

    Returns:
        jax.Array: array of the same shape as `x` with masked positions set to `value`
    """
    return jnp.where(mask, jnp.broadcast_to(value, x.shape), x)

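## Editorial sketch (not part of this commit): example of masked_fill on a
## simple 1-D input; positions where the mask is True receive the fill value:
##   masked_fill(jnp.array([1., 2., 3.]), jnp.array([False, True, False]), value=-jnp.inf)
##   # -> Array([1., -inf, 3.], dtype=float32)
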
@bind(jax.jit, static_argnums=[5, 6])
def cross_attention(dkey, params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array, n_heads: int=8, dropout_rate: float=0.0) -> jax.Array:
    """
    Run the cross-attention function given a tuple of parameters and two sequences (x1 and x2).
    The function takes in a query sequence x1 and a key-value sequence x2, and returns an output of the same shape as x1.
    T is the length of the query sequence and S is the length of the key-value sequence.
    Dq is the dimension of the query sequence and Dkv is the dimension of the key-value sequence.
    H is the number of attention heads.

    Args:
        dkey: JAX key to trigger any internal noise (drop-out)

        params (tuple): tuple of parameters

        x1 (jax.Array): query sequence. Shape: (B, T, Dq)

        x2 (jax.Array): key-value sequence. Shape: (B, S, Dkv)

        mask (jax.Array): mask tensor. Shape: (B, T, S)

        n_heads (int, optional): number of attention heads. Defaults to 8.

        dropout_rate (float, optional): dropout rate. Defaults to 0.0.

    Returns:
        jax.Array: output of cross-attention
    """
    B, T, Dq = x1.shape # the original shape
    _, S, Dkv = x2.shape
    # here we attend x2 to x1
    Wq, bq, Wk, bk, Wv, bv, Wout, bout = params
    # projection
    q = x1 @ Wq + bq # linear transformation (B, T, D)
    k = x2 @ Wk + bk # linear transformation (B, S, D)
    v = x2 @ Wv + bv # linear transformation (B, S, D)
    hidden = q.shape[-1]
    _hidden = hidden // n_heads
    q = q.reshape((B, T, n_heads, _hidden)).transpose([0, 2, 1, 3]) # (B, H, T, E)
    k = k.reshape((B, S, n_heads, _hidden)).transpose([0, 2, 1, 3]) # (B, H, S, E)
    v = v.reshape((B, S, n_heads, _hidden)).transpose([0, 2, 1, 3]) # (B, H, S, E)
    score = jnp.einsum("BHTE,BHSE->BHTS", q, k) / jnp.sqrt(_hidden) # Q @ K^T / sqrt(E); E = D // n_heads
    if mask is not None:
        Tq, Tk = q.shape[2], k.shape[2]
        assert mask.shape == (B, Tq, Tk), (mask.shape, (B, Tq, Tk))
        _mask = mask.reshape((B, 1, Tq, Tk)) # 'b tq tk -> b 1 tq tk'
        score = masked_fill(score, _mask, value=-jnp.inf) # mask out all positions that must not be attended to
    score = jax.nn.softmax(score, axis=-1) # (B, H, T, S)
    score = score.astype(q.dtype) # (B, H, T, S)
    if dropout_rate > 0.:
        score, _ = drop_out(dkey, score, rate=dropout_rate) ## NOTE: dropout is normally applied here
    attention = jnp.einsum("BHTS,BHSE->BHTE", score, v) # (B, H, T, E)
    attention = attention.transpose([0, 2, 1, 3]).reshape((B, T, -1)) # (B, H, T, E) => (B, T, D)
    return attention @ Wout + bout # (B, T, Dq)

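## Editorial sketch (not part of this commit): the expected layout of `params`
## for cross_attention, inferred from how AttentiveProbe builds its parameter
## tuple below (D is the total attention dimension, divisible by n_heads):
##   Wq: (Dq, D),   bq: (1, D)     -- query projection of x1
##   Wk: (Dkv, D),  bk: (1, D)     -- key projection of x2
##   Wv: (Dkv, D),  bv: (1, D)     -- value projection of x2
##   Wout: (D, Dq), bout: (1, Dq)  -- output projection back to the query width
##   out = cross_attention(dkey, (Wq, bq, Wk, bk, Wv, bv, Wout, bout),
##                         x1, x2, mask, n_heads)  # -> (B, T, Dq)
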
@bind(jax.jit, static_argnums=[4, 5, 6, 7, 8])
def run_attention_probe(
        dkey, params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False,
        use_softmax=True
):
    """
    Runs the full nonlinear attentive probe on input encodings (typically embedding vectors produced by some other model).

    Args:
        dkey: JAX key for any internal noise to be applied

        params: parameter tuple/list of the probe

        encodings: input encoding vectors/data

        mask: optional mask to be applied to the internal cross-attention

        n_heads: number of attention heads

        dropout: if >0, triggers drop-out applied internally to cross-attention

        use_LN: use layer normalization?

        use_LN_input: use layer normalization on the input encodings?

        use_softmax: should softmax be applied to the output of the attention probe? (useful for classification)

    Returns:
        output scores/probabilities, cross-attention (hidden) features
    """
    ## two separate dkeys, one for each dropout in the two attention blocks
    dkey1, dkey2 = random.split(dkey, 2)
    ## encodings (e.g., encoded image features): (B, hw, dim)
    learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout,\
        Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts, Wlnattn_mu,\
        Wlnattn_scale, Whid1, bhid1, Wln_mu1, Wln_scale1, Whid2,\
        bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3,\
        Wy, by, ln_in_mu, ln_in_scale, ln_in_mu2, ln_in_scale2 = params
    cross_attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
    if use_LN_input:
        learnable_query = layer_normalize(learnable_query, ln_in_mu, ln_in_scale)
        encodings = layer_normalize(encodings, ln_in_mu2, ln_in_scale2)
    features = cross_attention(dkey1, cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
    ## perform a single self-attention block (the features attend to themselves)
    self_attn_params = (Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts)
    skip = features
    if use_LN:
        features = layer_normalize(features, Wlnattn_mu, Wlnattn_scale)
    features = cross_attention(dkey2, self_attn_params, features, features, None, n_heads, dropout)
    features = features + skip
    features = features[:, 0] # (B, 1, dim) => (B, dim)
    ## MLP head
    skip = features
    if use_LN: ## normalize hidden layer output of probe predictor
        features = layer_normalize(features, Wln_mu1, Wln_scale1)
    features = jnp.matmul(features, Whid1) + bhid1
    features = gelu(features)
    if use_LN: ## normalize hidden layer output of probe predictor
        features = layer_normalize(features, Wln_mu2, Wln_scale2)
    features = jnp.matmul(features, Whid2) + bhid2
    features = gelu(features)
    if use_LN: ## normalize hidden layer output of probe predictor
        features = layer_normalize(features, Wln_mu3, Wln_scale3)
    features = jnp.matmul(features, Whid3) + bhid3
    features = features + skip
    outs = jnp.matmul(features, Wy) + by
    if use_softmax: ## apply softmax output nonlinearity
        # NOTE: check the softmax function; it could potentially yield NaN
        # gradients since there is a potential division by zero
        outs = jax.nn.softmax(outs)
    return outs, features

@bind(jax.jit, static_argnums=[5, 6, 7, 8, 9])
def eval_attention_probe(dkey, params, encodings, labels, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False, use_softmax=True):
    """
    Runs and evaluates the nonlinear attentive probe given a paired set of encoding vectors and externally assigned
    labels/regression targets.

    Args:
        dkey: JAX key to trigger any internal noise (as in drop-out)

        params: parameter tuple/list of the probe

        encodings: input encoding vectors/data

        labels: output target values (e.g., labels, regression target vectors)

        mask: optional mask to be applied to the internal cross-attention

        n_heads: number of attention heads

        dropout: if >0, triggers drop-out applied internally to cross-attention

        use_LN: use layer normalization?

        use_LN_input: use layer normalization on the input encodings?

        use_softmax: should softmax be applied to the output of the attention probe? (useful for classification)

    Returns:
        current loss value, output scores/probabilities
    """
    ## encodings: (B, hw, dim)
    outs, _ = run_attention_probe(dkey, params, encodings, mask, n_heads, dropout, use_LN, use_LN_input, use_softmax)
    if use_softmax: ## Multinoulli log likelihood for 1-of-K predictions
        L = -jnp.mean(jnp.sum(jnp.log(outs.clip(min=1e-5)) * labels, axis=1, keepdims=True))
    else: ## MSE for real-valued outputs
        L = jnp.mean(jnp.sum(jnp.square(outs - labels), axis=1, keepdims=True))
    return L, outs

class AttentiveProbe(Probe):
    """
    This implements a nonlinear attentive probe, which is useful for evaluating the quality of
    encodings/embeddings in light of some supervisory downstream data (e.g., label one-hot
    encodings or real-valued vector regression targets).

    Args:
        dkey: init seed key

        source_seq_length: length of input sequence (e.g., height x width of the image feature)

        input_dim: input dimensionality of probe

        out_dim: output dimensionality of probe

        num_heads: number of cross-attention heads

        attn_dim: total dimensionality of the cross-attention (split evenly across heads)

        target_seq_length: to pool, we set this to one (i.e., map the source sequence to a target sequence of length 1)

        learnable_query_dim: target sequence dim (output dimension of cross-attention portion of probe)

        batch_size: size of batches to process per internal call to update (or process)

        hid_dim: dimensionality of hidden layer(s) of MLP portion of probe

        use_LN: should layer normalization be used within MLP portions of probe or not?

        use_LN_input: should layer normalization be applied to the input encodings and learnable query?

        use_softmax: should a softmax be applied to output of probe or not?

        dropout: drop-out rate applied within the attention blocks during updates

        eta: learning rate of the probe's Adam optimizer

        eta_decay: multiplicative decay factor applied to the learning rate after each update

        min_eta: lower bound on the decayed learning rate
    """
    def __init__(
            self, dkey, source_seq_length, input_dim, out_dim, num_heads=8, attn_dim=64,
            target_seq_length=1, learnable_query_dim=32, batch_size=1, hid_dim=32,
            use_LN=True, use_LN_input=False, use_softmax=True, dropout=0.5, eta=0.0002,
            eta_decay=0.0, min_eta=1e-5, **kwargs
    ):
        super().__init__(dkey, batch_size, **kwargs)
        assert attn_dim % num_heads == 0, f"`attn_dim` must be divisible by `num_heads`. Got {attn_dim} and {num_heads}."
        assert learnable_query_dim % num_heads == 0, f"`learnable_query_dim` must be divisible by `num_heads`. Got {learnable_query_dim} and {num_heads}."
        self.dkey, *subkeys = random.split(self.dkey, 26)
        self.num_heads = num_heads
        self.source_seq_length = source_seq_length
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.use_softmax = use_softmax
        self.use_LN = use_LN
        self.use_LN_input = use_LN_input
        self.dropout = dropout

        sigma = 0.02
        ## cross-attention parameters
        Wq = random.normal(subkeys[0], (learnable_query_dim, attn_dim)) * sigma
        bq = random.normal(subkeys[1], (1, attn_dim)) * sigma
        Wk = random.normal(subkeys[2], (input_dim, attn_dim)) * sigma
        bk = random.normal(subkeys[3], (1, attn_dim)) * sigma
        Wv = random.normal(subkeys[4], (input_dim, attn_dim)) * sigma
        bv = random.normal(subkeys[5], (1, attn_dim)) * sigma
        Wout = random.normal(subkeys[6], (attn_dim, learnable_query_dim)) * sigma
        bout = random.normal(subkeys[7], (1, learnable_query_dim)) * sigma
        cross_attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
        ## self-attention parameters
        Wqs = random.normal(subkeys[8], (learnable_query_dim, learnable_query_dim)) * sigma
        bqs = random.normal(subkeys[9], (1, learnable_query_dim)) * sigma
        Wks = random.normal(subkeys[10], (learnable_query_dim, learnable_query_dim)) * sigma
        bks = random.normal(subkeys[11], (1, learnable_query_dim)) * sigma
        Wvs = random.normal(subkeys[12], (learnable_query_dim, learnable_query_dim)) * sigma
        bvs = random.normal(subkeys[13], (1, learnable_query_dim)) * sigma
        Wouts = random.normal(subkeys[14], (learnable_query_dim, learnable_query_dim)) * sigma
        bouts = random.normal(subkeys[15], (1, learnable_query_dim)) * sigma
        Wlnattn_mu = jnp.zeros((1, learnable_query_dim)) ## LN parameter (applied to output of attention)
        Wlnattn_scale = jnp.ones((1, learnable_query_dim)) ## LN parameter (applied to output of attention)
        self_attn_params = (Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts, Wlnattn_mu, Wlnattn_scale)
        learnable_query = jnp.zeros((batch_size, 1, learnable_query_dim)) # (B, T, D)
        self.mask = np.zeros((self.batch_size, target_seq_length, source_seq_length)).astype(bool) ## mask tensor
        self.dev_mask = np.zeros((self.dev_batch_size, target_seq_length, source_seq_length)).astype(bool)
        ## MLP parameters
        Whid1 = random.normal(subkeys[16], (learnable_query_dim, learnable_query_dim)) * sigma
        bhid1 = random.normal(subkeys[17], (1, learnable_query_dim)) * sigma
        Wln_mu1 = jnp.zeros((1, learnable_query_dim)) ## LN parameter
        Wln_scale1 = jnp.ones((1, learnable_query_dim)) ## LN parameter
        Whid2 = random.normal(subkeys[18], (learnable_query_dim, learnable_query_dim * 4)) * sigma
        bhid2 = random.normal(subkeys[19], (1, learnable_query_dim * 4)) * sigma
        Wln_mu2 = jnp.zeros((1, learnable_query_dim)) ## LN parameter
        Wln_scale2 = jnp.ones((1, learnable_query_dim)) ## LN parameter
        Whid3 = random.normal(subkeys[20], (learnable_query_dim * 4, learnable_query_dim)) * sigma
        bhid3 = random.normal(subkeys[21], (1, learnable_query_dim)) * sigma
        Wln_mu3 = jnp.zeros((1, learnable_query_dim * 4)) ## LN parameter
        Wln_scale3 = jnp.ones((1, learnable_query_dim * 4)) ## LN parameter
        Wy = random.normal(subkeys[22], (learnable_query_dim, out_dim)) * sigma
        by = random.normal(subkeys[23], (1, out_dim)) * sigma
        mlp_params = (Whid1, bhid1, Wln_mu1, Wln_scale1, Whid2, bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3, Wy, by)
        ## finally, define LN parameters for the inputs to the attention
        ln_in_mu = jnp.zeros((1, learnable_query_dim)) ## LN parameter
        ln_in_scale = jnp.ones((1, learnable_query_dim)) ## LN parameter
        ln_in_mu2 = jnp.zeros((1, input_dim)) ## LN parameter
        ln_in_scale2 = jnp.ones((1, input_dim)) ## LN parameter
        ln_in_params = (ln_in_mu, ln_in_scale, ln_in_mu2, ln_in_scale2)
        self.probe_params = (learnable_query, *cross_attn_params, *self_attn_params, *mlp_params, *ln_in_params)

        ## set up gradient calculator
        self.grad_fx = jax.value_and_grad(eval_attention_probe, argnums=1, has_aux=True)
        ## set up update rule/optimizer
        self.optim_params = adam.adam_init(self.probe_params)
        ## learning rate scheduling
        self.eta = eta
        self.eta_decay = eta_decay
        self.min_eta = min_eta

        ## finally, the dkey for the noise key
        self.noise_key = subkeys[24]

    def process(self, embeddings, dkey=None):
        noise_key = self.noise_key
        if dkey is not None:
            dkey, *subkeys = random.split(dkey, 2)
            noise_key = subkeys[0]
        outs, feats = run_attention_probe(
            noise_key, self.probe_params, embeddings, self.dev_mask, self.num_heads, 0.0,
            use_LN=self.use_LN, use_LN_input=self.use_LN_input, use_softmax=self.use_softmax
        )
        return outs

    def update(self, embeddings, labels, dkey=None):
        noise_key = self.noise_key
        if dkey is not None:
            dkey, *subkeys = random.split(dkey, 2)
            noise_key = subkeys[0]
        outputs, grads = self.grad_fx(
            noise_key, self.probe_params, embeddings, labels, self.mask, self.num_heads, dropout=self.dropout,
            use_LN=self.use_LN, use_LN_input=self.use_LN_input, use_softmax=self.use_softmax
        )
        loss, predictions = outputs
        ## adjust parameters of probe
        self.optim_params, self.probe_params = adam.adam_step(
            self.optim_params, self.probe_params, grads, eta=self.eta
        )
        ## decay the learning rate toward its floor
        self.eta = max(self.min_eta, self.eta - self.eta_decay * self.eta)
        return loss, predictions
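
To illustrate how the new probe might be driven, here is a minimal, hypothetical usage sketch (not part of this commit) on synthetic data. It assumes the new package __init__ shown above belongs to ngclearn.utils.analysis, that the Probe base class (not in this diff) defaults its dev batch size to batch_size, and it uses made-up dimensions; both attn_dim and learnable_query_dim must be divisible by num_heads.

from jax import random, numpy as jnp
from ngclearn.utils.analysis import AttentiveProbe  ## path assumed from the new package __init__

key = random.PRNGKey(42)
key, emb_key, lab_key = random.split(key, 3)
B, S, D_in, n_classes = 8, 16, 64, 10  ## batch size, source length, encoding dim, number of classes
probe = AttentiveProbe(
    key, source_seq_length=S, input_dim=D_in, out_dim=n_classes,
    num_heads=4, attn_dim=64, learnable_query_dim=32, batch_size=B,
    use_softmax=True, dropout=0.1, eta=0.0002
)
embeddings = random.normal(emb_key, (B, S, D_in))  ## stand-in for frozen encoder features
labels = jnp.eye(n_classes)[random.randint(lab_key, (B,), 0, n_classes)]  ## one-hot targets
for _ in range(200):
    loss, preds = probe.update(embeddings, labels)  ## one Adam step on the probe parameters
scores = probe.process(embeddings)  ## inference pass (no drop-out); assumes dev batch size == B

In this sketch, update() takes one Adam step on the probe parameters and returns the current loss and predictions, while process() runs the probe without drop-out to produce output scores.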

0 commit comments
