
Commit d0df86e

committed
change head_dim to attn_dim, and modify the MLP to be as similar as possible to the attentive probing pattern
1 parent 247de74 commit d0df86e

File tree

1 file changed: +22 −17 lines


ngclearn/utils/analysis/attentive_probe.py

Lines changed: 22 additions & 17 deletions
@@ -102,11 +102,15 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
     attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
     features = cross_attention(attn_params, learnable_query, encodings, mask, n_heads, dropout)
     features = features[:, 0] # (B, 1, dim) => (B, dim)
-    hids = jnp.matmul((features + learnable_query[:, 0]), Whid) + bhid
-    hids = gelu(hids)
+    # MLP
+    residual = features
     if use_LN: ## normalize hidden layer output of probe predictor
-        hids = layer_normalize(hids, Wln_mu, Wln_scale)
-    outs = jnp.matmul(hids, Wy) + by
+        features = layer_normalize(features, Wln_mu, Wln_scale)
+    features = jnp.matmul((features), Whid) + bhid
+    features = gelu(features)
+    features = residual + features
+
+    outs = jnp.matmul(features, Wy) + by
     if use_softmax: ## apply softmax output nonlinearity
         outs = softmax(outs)
     return outs, features
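For reference, the rewritten block is the standard pre-LN residual MLP: normalize, project, GELU, then add the skip connection, with the output head Wy reading the post-residual features. It also drops the old features + learnable_query[:, 0] mixing in favor of a plain residual around the MLP. A minimal standalone sketch of the same computation, assuming ngclearn's layer_normalize behaves as an ordinary layer norm (inlined here with an assumed epsilon of 1e-6):

import jax.numpy as jnp
from jax.nn import gelu

def residual_mlp(x, Whid, bhid, ln_mu, ln_scale, use_LN=True, eps=1e-6):
    ## pre-LN residual MLP block, mirroring the pattern this commit adopts
    residual = x
    if use_LN:
        ## inlined stand-in for ngclearn's layer_normalize helper
        mu = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True)
        x = (x - mu) / jnp.sqrt(var + eps) * ln_scale + ln_mu
    x = gelu(jnp.matmul(x, Whid) + bhid)  ## project then GELU
    return residual + x  ## skip connection keeps the attention features in the output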
@@ -178,10 +182,11 @@ class AttentiveProbe(Probe):

     """
     def __init__(
-        self, dkey, source_seq_length, input_dim, out_dim, num_heads=8, head_dim=64,
+        self, dkey, source_seq_length, input_dim, out_dim, num_heads=8, attn_dim=64,
         target_seq_length=1, learnable_query_dim=31, batch_size=1, hid_dim=32, use_LN=True, use_softmax=True, **kwargs
     ):
         super().__init__(dkey, batch_size, **kwargs)
+        assert attn_dim % num_heads == 0, f"`attn_dim` must be divisible by `num_heads`. Got {attn_dim} and {num_heads}."
         self.dkey, *subkeys = random.split(self.dkey, 12)
         self.num_heads = num_heads
         self.source_seq_length = source_seq_length
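The new assert guards the multi-head split. cross_attention's internals are not part of this diff, but the usual reason attn_dim must be divisible by num_heads is the per-head reshape, sketched here with illustrative sizes:

import jax.numpy as jnp

B, T, n_heads, attn_dim = 2, 5, 8, 64  ## example sizes; attn_dim % n_heads == 0
q = jnp.zeros((B, T, attn_dim))  ## stand-in for the projected queries
## split the shared attention width into n_heads slices of attn_dim // n_heads each;
## this reshape raises an error when attn_dim is not a multiple of n_heads
q_heads = q.reshape(B, T, n_heads, attn_dim // n_heads)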
@@ -192,24 +197,24 @@ def __init__(

         sigma = 0.05
         ## cross-attention parameters
-        Wq = random.normal(subkeys[0], (learnable_query_dim, head_dim)) * sigma
-        bq = random.normal(subkeys[1], (1, head_dim)) * sigma
-        Wk = random.normal(subkeys[2], (input_dim, head_dim)) * sigma
-        bk = random.normal(subkeys[3], (1, head_dim)) * sigma
-        Wv = random.normal(subkeys[4], (input_dim, head_dim)) * sigma
-        bv = random.normal(subkeys[5], (1, head_dim)) * sigma
-        Wout = random.normal(subkeys[6], (head_dim, learnable_query_dim)) * sigma
+        Wq = random.normal(subkeys[0], (learnable_query_dim, attn_dim)) * sigma
+        bq = random.normal(subkeys[1], (1, attn_dim)) * sigma
+        Wk = random.normal(subkeys[2], (input_dim, attn_dim)) * sigma
+        bk = random.normal(subkeys[3], (1, attn_dim)) * sigma
+        Wv = random.normal(subkeys[4], (input_dim, attn_dim)) * sigma
+        bv = random.normal(subkeys[5], (1, attn_dim)) * sigma
+        Wout = random.normal(subkeys[6], (attn_dim, learnable_query_dim)) * sigma
         bout = random.normal(subkeys[7], (1, learnable_query_dim)) * sigma
         #params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
         learnable_query = jnp.zeros((batch_size, 1, learnable_query_dim)) # (B, T, D)
         #self.all_params = (learnable_query, *params)
         self.mask = np.zeros((batch_size, target_seq_length, source_seq_length)).astype(bool) ## mask tensor
         ## MLP parameters
-        Whid = random.normal(subkeys[8], (learnable_query_dim, hid_dim)) * sigma
-        bhid = random.normal(subkeys[9], (1, hid_dim)) * sigma
-        Wln_mu = jnp.zeros((1, hid_dim))
-        Wln_scale = jnp.ones((1, hid_dim))
-        Wy = random.normal(subkeys[8], (hid_dim, out_dim)) * sigma
+        Whid = random.normal(subkeys[8], (learnable_query_dim, learnable_query_dim)) * sigma
+        bhid = random.normal(subkeys[9], (1, learnable_query_dim)) * sigma
+        Wln_mu = jnp.zeros((1, learnable_query_dim))
+        Wln_scale = jnp.ones((1, learnable_query_dim))
+        Wy = random.normal(subkeys[8], (learnable_query_dim, out_dim)) * sigma
         by = random.normal(subkeys[9], (1, out_dim)) * sigma
         #mlp_params = (Whid, bhid, Wln_mu, Wln_scale, Wy, by)
         self.probe_params = (learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout, Whid, bhid, Wln_mu, Wln_scale, Wy, by)
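Since head_dim was a keyword argument with a default, any call site that passed it by name needs the attn_dim rename. A hypothetical construction under the new signature, with illustrative sizes:

from jax import random
from ngclearn.utils.analysis.attentive_probe import AttentiveProbe

dkey = random.PRNGKey(1234)
probe = AttentiveProbe(
    dkey, source_seq_length=16, input_dim=128, out_dim=10,
    num_heads=8, attn_dim=64  ## was head_dim=64 before this commit
)

Note that Whid is now square in learnable_query_dim, so the hid_dim argument no longer shapes the MLP parameters after this change.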
