Commit f402d98

update attentive probe code
1 parent 8a36e40 commit f402d98

File tree: 1 file changed (+57, -21 lines changed)


ngclearn/utils/analysis/attentive_probe.py

Lines changed: 57 additions & 21 deletions
@@ -78,7 +78,7 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
     """
     Runs full nonlinear attentive probe on input encodings (typically embedding vectors produced by some other model).
 
-    Args:
+    Args:
         params: parameters tuple/list of probe
 
         encodings: input encoding vectors/data
@@ -98,18 +98,35 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
     """
     # encoded_image_feature: (B, hw, dim)
     #learnable_query, *_params) = params
-    learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout, Whid, bhid, Wln_mu, Wln_scale, Wy, by = params
-    attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
-    features = cross_attention(attn_params, learnable_query, encodings, mask, n_heads, dropout)
+    learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout,\
+        Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts, Wlnattn_mu,\
+        Wlnattn_scale, Whid1, bhid1, Wln_mu1, Wln_scale1, Whid2,\
+        bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3, Wy, by = params
+    cross_attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
+    features = cross_attention(cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
+    # Perform a single self-attention block here
+    # Self-Attention
+    self_attn_params = (Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts)
+    skip = features
+    if use_LN:
+        features = layer_normalize(features, Wlnattn_mu, Wlnattn_scale)
+    features = cross_attention(self_attn_params, features, features, None, n_heads, dropout)
+    features = features + skip
     features = features[:, 0] # (B, 1, dim) => (B, dim)
     # MLP
-    residual = features
+    skip = features
     if use_LN: ## normalize hidden layer output of probe predictor
-        features = layer_normalize(features, Wln_mu, Wln_scale)
-    features = jnp.matmul((features), Whid) + bhid
+        features = layer_normalize(features, Wln_mu1, Wln_scale1)
+    features = jnp.matmul((features), Whid1) + bhid1
     features = gelu(features)
-    features = residual + features
-
+    if use_LN: ## normalize hidden layer output of probe predictor
+        features = layer_normalize(features, Wln_mu2, Wln_scale2)
+    features = jnp.matmul((features), Whid2) + bhid2
+    features = gelu(features)
+    if use_LN: ## normalize hidden layer output of probe predictor
+        features = layer_normalize(features, Wln_mu3, Wln_scale3)
+    features = jnp.matmul((features), Whid3) + bhid3
+    features = features + skip
     outs = jnp.matmul(features, Wy) + by
     if use_softmax: ## apply softmax output nonlinearity
         outs = softmax(outs)
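
For orientation, the snippet below is a minimal, self-contained sketch (not ngclearn's code) of the pre-norm residual pattern the updated forward pass follows: LayerNorm, then the sub-layer, then the skip connection, first around the new self-attention step and then around the widened three-layer GeLU MLP. `toy_attention` is a single-head stand-in for the library's multi-head `cross_attention(...)`, and `layer_normalize` here is a simplified version of the probe's LayerNorm; all sizes are illustrative.

# Hedged sketch, NOT ngclearn's implementation.
import jax
import jax.numpy as jnp

def layer_normalize(x, mu, scale, eps=1e-6):
    # Normalize the last axis, then apply a learnable shift (mu) and gain (scale).
    m = jnp.mean(x, axis=-1, keepdims=True)
    v = jnp.var(x, axis=-1, keepdims=True)
    return (x - m) / jnp.sqrt(v + eps) * scale + mu

def toy_attention(q, kv, Wq, Wk, Wv):
    # Single-head scaled dot-product attention (no mask, dropout, or output projection).
    Q, K, V = q @ Wq, kv @ Wk, kv @ Wv
    scores = Q @ jnp.swapaxes(K, -1, -2) / jnp.sqrt(Q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ V

def probe_style_blocks(x, attn_params, ln_attn, mlp_params, use_LN=True):
    # x: (B, T, D). Mirrors the diff's order of operations:
    #   skip -> (LN) -> self-attention -> + skip, take token 0,
    #   then skip -> (LN -> linear -> GeLU) x2 -> LN -> linear -> + skip.
    Wq, Wk, Wv = attn_params
    mu_a, s_a = ln_attn
    skip = x
    h = layer_normalize(x, mu_a, s_a) if use_LN else x
    x = toy_attention(h, h, Wq, Wk, Wv) + skip
    x = x[:, 0]  # (B, T, D) -> (B, D)
    W1, b1, mu1, s1, W2, b2, mu2, s2, W3, b3, mu3, s3 = mlp_params
    skip = x
    h = layer_normalize(x, mu1, s1) if use_LN else x
    h = jax.nn.gelu(h @ W1 + b1)                 # D -> D
    h = layer_normalize(h, mu2, s2) if use_LN else h
    h = jax.nn.gelu(h @ W2 + b2)                 # D -> 4D
    h = layer_normalize(h, mu3, s3) if use_LN else h
    return h @ W3 + b3 + skip                    # 4D -> D, residual over the whole MLP

# Shape check with illustrative sizes (B=2, T=3, D=8):
key = jax.random.PRNGKey(0)
D, sigma = 8, 0.05
ks = jax.random.split(key, 6)
attn_params = tuple(sigma * jax.random.normal(k, (D, D)) for k in ks[:3])
ln_attn = (jnp.zeros((1, D)), jnp.ones((1, D)))
mlp_params = (sigma * jax.random.normal(ks[3], (D, D)), jnp.zeros((1, D)),
              jnp.zeros((1, D)), jnp.ones((1, D)),
              sigma * jax.random.normal(ks[4], (D, 4 * D)), jnp.zeros((1, 4 * D)),
              jnp.zeros((1, D)), jnp.ones((1, D)),
              sigma * jax.random.normal(ks[5], (4 * D, D)), jnp.zeros((1, D)),
              jnp.zeros((1, 4 * D)), jnp.ones((1, 4 * D)))
x = jnp.ones((2, 3, D))
print(probe_style_blocks(x, attn_params, ln_attn, mlp_params).shape)  # (2, 8)

The actual probe additionally pools the encodings with a learnable query via cross-attention before this block and uses multi-head attention with a mask, dropout, and an output projection; the sketch keeps only the residual/normalization structure.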
@@ -183,11 +200,12 @@ class AttentiveProbe(Probe):
     """
     def __init__(
             self, dkey, source_seq_length, input_dim, out_dim, num_heads=8, attn_dim=64,
-            target_seq_length=1, learnable_query_dim=31, batch_size=1, hid_dim=32, use_LN=True, use_softmax=True, **kwargs
+            target_seq_length=1, learnable_query_dim=32, batch_size=1, hid_dim=32, use_LN=True, use_softmax=True, **kwargs
     ):
         super().__init__(dkey, batch_size, **kwargs)
         assert attn_dim % num_heads == 0, f"`attn_dim` must be divisible by `num_heads`. Got {attn_dim} and {num_heads}."
-        self.dkey, *subkeys = random.split(self.dkey, 12)
+        assert learnable_query_dim % num_heads == 0, f"`learnable_query_dim` must be divisible by `num_heads`. Got {learnable_query_dim} and {num_heads}."
+        self.dkey, *subkeys = random.split(self.dkey, 25)
         self.num_heads = num_heads
         self.source_seq_length = source_seq_length
         self.input_dim = input_dim
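
As context for the key-splitting change, a small sketch of the PRNG bookkeeping (counts read off this diff; the snippet itself is illustrative, not part of the file): splitting into 25 keys retains the first as the new `dkey` and leaves 24 subkeys, one per randomly initialized tensor in the updated constructor.

import jax

dkey = jax.random.PRNGKey(42)                # any seed; illustrative only
dkey, *subkeys = jax.random.split(dkey, 25)
assert len(subkeys) == 24   # subkeys[0:8]   -> cross-attention Wq/bq/Wk/bk/Wv/bv/Wout/bout
                            # subkeys[8:16]  -> self-attention Wqs/bqs/.../Wouts/bouts
                            # subkeys[16:22] -> Whid1/bhid1/Whid2/bhid2/Whid3/bhid3
                            # subkeys[22:24] -> Wy/by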
@@ -205,19 +223,37 @@ def __init__(
         bv = random.normal(subkeys[5], (1, attn_dim)) * sigma
         Wout = random.normal(subkeys[6], (attn_dim, learnable_query_dim)) * sigma
         bout = random.normal(subkeys[7], (1, learnable_query_dim)) * sigma
-        #params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
+        cross_attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
+        Wqs = random.normal(subkeys[8], (learnable_query_dim, learnable_query_dim)) * sigma
+        bqs = random.normal(subkeys[9], (1, learnable_query_dim)) * sigma
+        Wks = random.normal(subkeys[10], (learnable_query_dim, learnable_query_dim)) * sigma
+        bks = random.normal(subkeys[11], (1, learnable_query_dim)) * sigma
+        Wvs = random.normal(subkeys[12], (learnable_query_dim, learnable_query_dim)) * sigma
+        bvs = random.normal(subkeys[13], (1, learnable_query_dim)) * sigma
+        Wouts = random.normal(subkeys[14], (learnable_query_dim, learnable_query_dim)) * sigma
+        bouts = random.normal(subkeys[15], (1, learnable_query_dim)) * sigma
+        Wlnattn_mu = jnp.zeros((1, learnable_query_dim))
+        Wlnattn_scale = jnp.ones((1, learnable_query_dim))
+        self_attn_params = (Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts, Wlnattn_mu, Wlnattn_scale)
         learnable_query = jnp.zeros((batch_size, 1, learnable_query_dim)) # (B, T, D)
-        #self.all_params = (learnable_query, *params)
         self.mask = np.zeros((batch_size, target_seq_length, source_seq_length)).astype(bool) ## mask tensor
         ## MLP parameters
-        Whid = random.normal(subkeys[8], (learnable_query_dim, learnable_query_dim)) * sigma
-        bhid = random.normal(subkeys[9], (1, learnable_query_dim)) * sigma
-        Wln_mu = jnp.zeros((1, learnable_query_dim))
-        Wln_scale = jnp.ones((1, learnable_query_dim))
-        Wy = random.normal(subkeys[8], (learnable_query_dim, out_dim)) * sigma
-        by = random.normal(subkeys[9], (1, out_dim)) * sigma
-        #mlp_params = (Whid, bhid, Wln_mu, Wln_scale, Wy, by)
-        self.probe_params = (learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout, Whid, bhid, Wln_mu, Wln_scale, Wy, by)
+        Whid1 = random.normal(subkeys[16], (learnable_query_dim, learnable_query_dim)) * sigma
+        bhid1 = random.normal(subkeys[17], (1, learnable_query_dim)) * sigma
+        Wln_mu1 = jnp.zeros((1, learnable_query_dim))
+        Wln_scale1 = jnp.ones((1, learnable_query_dim))
+        Whid2 = random.normal(subkeys[18], (learnable_query_dim, learnable_query_dim * 4)) * sigma
+        bhid2 = random.normal(subkeys[19], (1, learnable_query_dim * 4)) * sigma
+        Wln_mu2 = jnp.zeros((1, learnable_query_dim))
+        Wln_scale2 = jnp.ones((1, learnable_query_dim))
+        Whid3 = random.normal(subkeys[20], (learnable_query_dim * 4, learnable_query_dim)) * sigma
+        bhid3 = random.normal(subkeys[21], (1, learnable_query_dim)) * sigma
+        Wln_mu3 = jnp.zeros((1, learnable_query_dim * 4))
+        Wln_scale3 = jnp.ones((1, learnable_query_dim * 4))
+        Wy = random.normal(subkeys[22], (learnable_query_dim, out_dim)) * sigma
+        by = random.normal(subkeys[23], (1, out_dim)) * sigma
+        mlp_params = (Whid1, bhid1, Wln_mu1, Wln_scale1, Whid2, bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3, Wy, by)
+        self.probe_params = (learnable_query, *cross_attn_params, *self_attn_params, *mlp_params)
 
         ## set up gradient calculator
         self.grad_fx = jax.value_and_grad(eval_attention_probe, argnums=0, has_aux=True)
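
A quick sanity check (names copied from this diff; the snippet is illustrative, not part of the file) that the flattened `self.probe_params` tuple lines up one-for-one with the 33-name unpack at the top of `run_attention_probe`. Note also that the old initializer reused `subkeys[8]` and `subkeys[9]` for both the hidden layer and the readout weights; the rewrite gives every random tensor its own subkey.

cross_attn_names = ["Wq", "bq", "Wk", "bk", "Wv", "bv", "Wout", "bout"]
self_attn_names = ["Wqs", "bqs", "Wks", "bks", "Wvs", "bvs", "Wouts", "bouts",
                   "Wlnattn_mu", "Wlnattn_scale"]
mlp_names = ["Whid1", "bhid1", "Wln_mu1", "Wln_scale1",
             "Whid2", "bhid2", "Wln_mu2", "Wln_scale2",
             "Whid3", "bhid3", "Wln_mu3", "Wln_scale3", "Wy", "by"]
probe_param_names = ["learnable_query", *cross_attn_names, *self_attn_names, *mlp_names]
assert len(probe_param_names) == 33  # matches the 33 names unpacked from `params` above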
