
Commit aeabf61

update attentive probe with input layer norm
1 parent 099c588 commit aeabf61

File tree

1 file changed (+15, −6 lines)

ngclearn/utils/analysis/attentive_probe.py

Lines changed: 15 additions & 6 deletions
@@ -73,8 +73,8 @@ def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array
     attention = attention.transpose([0, 2, 1, 3]).reshape((B, T, -1)) # (B, T, H, E) => (B, T, D)
     return attention @ Wout + bout # (B, T, Dq)
 
-@bind(jax.jit, static_argnums=[3, 4, 5, 6])
-def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_softmax=True):
+@bind(jax.jit, static_argnums=[3, 4, 5, 6, 7])
+def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=True, use_softmax=True):
     """
     Runs full nonlinear attentive probe on input encodings (typically embedding vectors produced by some other model).
 
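The `static_argnums` list grows from `[3, 4, 5, 6]` to `[3, 4, 5, 6, 7]` because the new `use_LN_input` keyword is inserted before `use_softmax`, shifting `use_softmax` to argument index 7; both flags drive Python-level `if` branches, so they must stay static under `jax.jit`. A minimal sketch of the same pattern, using a toy function rather than ngclearn's `run_attention_probe`:

import jax
import jax.numpy as jnp

# Toy stand-in (hypothetical, not ngclearn code): boolean flags that select a
# Python-level branch inside a jitted function must be declared static.
def probe_like(x, n_heads: int, use_LN_input: bool = True):
    if use_LN_input:  # Python `if` on a static argument
        x = (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + 1e-6)
    return x.reshape((x.shape[0], n_heads, -1))

probe_like_jit = jax.jit(probe_like, static_argnums=(1, 2))

x = jnp.ones((2, 8))
print(probe_like_jit(x, 4, True).shape)   # (2, 4, 2)
print(probe_like_jit(x, 4, False).shape)  # recompiles once for the new static value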
@@ -101,8 +101,11 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
     learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout,\
         Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts, Wlnattn_mu,\
         Wlnattn_scale, Whid1, bhid1, Wln_mu1, Wln_scale1, Whid2,\
-        bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3, Wy, by = params
+        bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3,\
+        Wy, by, ln_in_mu, ln_in_scale = params
     cross_attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
+    if use_LN_input:
+        learnable_query = layer_normalize(learnable_query, ln_in_mu, ln_in_scale)
     features = cross_attention(cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
     # Perform a single self-attention block here
     # Self-Attention
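This hunk is the functional core of the commit: when `use_LN_input` is enabled, the learnable query is layer-normalized with the new `ln_in_mu` / `ln_in_scale` parameters before it is handed to cross-attention. A minimal sketch of that step, assuming `layer_normalize` follows the usual recipe (zero-mean, unit-variance over the feature axis, then a learned scale and shift); ngclearn's actual helper may differ in detail:

import jax.numpy as jnp

def layer_normalize(x, mu, scale, eps=1e-6):
    # Assumed semantics of ngclearn's helper: normalize each feature vector,
    # then apply the learned shift (mu) and scale parameters.
    xhat = (x - x.mean(axis=-1, keepdims=True)) / (x.std(axis=-1, keepdims=True) + eps)
    return xhat * scale + mu

learnable_query_dim = 32                                 # default from AttentiveProbe.__init__
ln_in_mu = jnp.zeros((1, learnable_query_dim))           # initialization used in this commit
ln_in_scale = jnp.ones((1, learnable_query_dim))
learnable_query = jnp.arange(32.0).reshape((1, 1, 32))   # (batch, target_seq_length, Dq); shape assumed

use_LN_input = True
if use_LN_input:
    learnable_query = layer_normalize(learnable_query, ln_in_mu, ln_in_scale)
print(learnable_query.mean(), learnable_query.std())     # ~0 and ~1 per query vector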
@@ -200,7 +203,7 @@ class AttentiveProbe(Probe):
     """
     def __init__(
         self, dkey, source_seq_length, input_dim, out_dim, num_heads=8, attn_dim=64,
-        target_seq_length=1, learnable_query_dim=32, batch_size=1, hid_dim=32, use_LN=True, use_softmax=True, **kwargs
+        target_seq_length=1, learnable_query_dim=32, batch_size=1, hid_dim=32, use_LN=True, use_LN_input=True, use_softmax=True, **kwargs
     ):
         super().__init__(dkey, batch_size, **kwargs)
         assert attn_dim % num_heads == 0, f"`attn_dim` must be divisible by `num_heads`. Got {attn_dim} and {num_heads}."
@@ -212,6 +215,7 @@ def __init__(
         self.out_dim = out_dim
         self.use_softmax = use_softmax
         self.use_LN = use_LN
+        self.use_LN_input = use_LN_input
 
         sigma = 0.05
         ## cross-attention parameters
@@ -254,7 +258,11 @@ def __init__(
         Wy = random.normal(subkeys[22], (learnable_query_dim, out_dim)) * sigma
         by = random.normal(subkeys[23], (1, out_dim)) * sigma
         mlp_params = (Whid1, bhid1, Wln_mu1, Wln_scale1, Whid2, bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3, Wy, by)
-        self.probe_params = (learnable_query, *cross_attn_params, *self_attn_params, *mlp_params)
+        # Finally, define ln for the input to the attention
+        ln_in_mu = jnp.zeros((1, learnable_query_dim)) ## LN parameter
+        ln_in_scale = jnp.ones((1, learnable_query_dim)) ## LN parameter
+        ln_in_params = (ln_in_mu, ln_in_scale)
+        self.probe_params = (learnable_query, *cross_attn_params, *self_attn_params, *mlp_params, *ln_in_params)
 
         ## set up gradient calculator
         self.grad_fx = jax.value_and_grad(eval_attention_probe, argnums=0, has_aux=True)
@@ -294,8 +302,9 @@ def update(self, embedding_sequence, labels, dkey=None):
         """
         # TODO: put in dkey to facilitate dropout
         ## compute partial derivatives / adjustments to probe parameters
+        # NOTE: Viet: Change back to 0.0 for now for the code to run
         outputs, grads = self.grad_fx(
-            self.probe_params, embedding_sequence, labels, self.mask, self.num_heads, dropout=0.5, use_LN=self.use_LN,
+            self.probe_params, embedding_sequence, labels, self.mask, self.num_heads, dropout=0.0, use_LN=self.use_LN,
             use_softmax=self.use_softmax
         )
         loss, predictions = outputs
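For context, a hedged usage sketch of the constructor as it reads after this commit; the import path mirrors the file location above, and the input/label shapes are illustrative assumptions rather than values taken from the repository:

from jax import numpy as jnp, random
from ngclearn.utils.analysis.attentive_probe import AttentiveProbe

dkey = random.PRNGKey(0)
# Dimensions below are illustrative assumptions, not repository defaults.
probe = AttentiveProbe(
    dkey, source_seq_length=16, input_dim=128, out_dim=10,
    num_heads=8, attn_dim=64, batch_size=4, hid_dim=32,
    use_LN=True, use_LN_input=True,   # new flag introduced by this commit
    use_softmax=True
)
embeddings = jnp.ones((4, 16, 128))   # (batch, source_seq_length, input_dim); layout assumed
labels = jnp.zeros((4, 10))           # (batch, out_dim); target encoding assumed
probe.update(embeddings, labels)      # runs the probe with dropout=0.0 after this change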
