
Commit 099c588

Author: Alexander Ororbia
Message: minor edits to attn probe
Parent: 155d830

File tree: 1 file changed (+8, −5 lines)


ngclearn/utils/analysis/attentive_probe.py

Lines changed: 8 additions & 5 deletions
@@ -236,7 +236,8 @@ def __init__(
         Wlnattn_scale = jnp.ones((1, learnable_query_dim)) ## LN parameter (applied to output of attention)
         self_attn_params = (Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts, Wlnattn_mu, Wlnattn_scale)
         learnable_query = jnp.zeros((batch_size, 1, learnable_query_dim)) # (B, T, D)
-        self.mask = np.zeros((batch_size, target_seq_length, source_seq_length)).astype(bool) ## mask tensor
+        self.mask = np.zeros((self.batch_size, target_seq_length, source_seq_length)).astype(bool) ## mask tensor
+        self.dev_mask = np.zeros((self.dev_batch_size, target_seq_length, source_seq_length)).astype(bool)
         ## MLP parameters
         Whid1 = random.normal(subkeys[16], (learnable_query_dim, learnable_query_dim)) * sigma
         bhid1 = random.normal(subkeys[17], (1, learnable_query_dim)) * sigma
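
This hunk splits the single mask into a training-batch mask (self.mask, sized by self.batch_size) and a new development-set mask (self.dev_mask, sized by self.dev_batch_size), so training and evaluation can run on differently sized batches. A minimal, self-contained sketch of the construction; the batch sizes and sequence lengths below are illustrative assumptions, not values from the repository:

import numpy as np

batch_size, dev_batch_size = 64, 128           ## assumed batch sizes
target_seq_length, source_seq_length = 1, 196  ## assumed sequence lengths

## all-False boolean masks: no positions are masked out by default
mask = np.zeros((batch_size, target_seq_length, source_seq_length)).astype(bool)
dev_mask = np.zeros((dev_batch_size, target_seq_length, source_seq_length)).astype(bool)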
@@ -259,7 +260,7 @@ def __init__(
         self.grad_fx = jax.value_and_grad(eval_attention_probe, argnums=0, has_aux=True)
         ## set up update rule/optimizer
         self.optim_params = adam.adam_init(self.probe_params)
-        self.eta = 0.001
+        self.eta = 0.0002 #0.001

     def process(self, embedding_sequence):
         """
@@ -271,13 +272,14 @@ def process(self, embedding_sequence):
         Returns:
             probe output scores/probability values
         """
+        #print(embedding_sequence.shape)
         outs, feats = run_attention_probe(
-            self.probe_params, embedding_sequence, self.mask, self.num_heads, 0.0, use_LN=self.use_LN,
+            self.probe_params, embedding_sequence, self.dev_mask, self.num_heads, 0.0, use_LN=self.use_LN,
             use_softmax=self.use_softmax
         )
         return outs

-    def update(self, embedding_sequence, labels):
+    def update(self, embedding_sequence, labels, dkey=None):
         """
         Runs and updates this probe given an input batch of sequences of encodings/embeddings and their externally
         assigned labels/target vector values.
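
With this change, process() applies self.dev_mask during inference while update() (below) keeps self.mask for training, so each path must be fed batches whose leading dimension matches its mask. A self-contained illustration of the masking convention, with assumed shapes and plain NumPy standing in for the probe's internals:

import numpy as np

dev_batch_size, T, S = 128, 1, 196              ## assumed (batch, target, source) sizes
scores = np.random.randn(dev_batch_size, T, S)  ## mock attention scores
dev_mask = np.zeros((dev_batch_size, T, S)).astype(bool)

## masked (True) positions would be driven to -inf before a softmax;
## the all-False default leaves every score untouched
masked_scores = np.where(dev_mask, -np.inf, scores)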
@@ -290,9 +292,10 @@ def update(self, embedding_sequence, labels):
         Returns:
             probe output scores/probability values
         """
+        # TODO: put in dkey to facilitate dropout
         ## compute partial derivatives / adjustments to probe parameters
         outputs, grads = self.grad_fx(
-            self.probe_params, embedding_sequence, labels, self.mask, self.num_heads, dropout=0., use_LN=self.use_LN,
+            self.probe_params, embedding_sequence, labels, self.mask, self.num_heads, dropout=0.5, use_LN=self.use_LN,
             use_softmax=self.use_softmax
         )
         loss, predictions = outputs
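
The TODO above notes that dkey should eventually drive the now-nonzero dropout (raised here from 0. to 0.5). A self-contained sketch of key-driven inverted dropout in JAX, the kind of mechanism a threaded dkey would feed; the dropout function and names here are illustrative, not ngclearn's API:

import jax.numpy as jnp
from jax import random

def dropout(x, rate, key):
    ## inverted dropout: zero entries with probability `rate`, rescale survivors
    keep = random.bernoulli(key, p=1.0 - rate, shape=x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

dkey = random.PRNGKey(1234)
dkey, subkey = random.split(dkey)  ## split once per update() call
x = jnp.ones((4, 8))
y = dropout(x, rate=0.5, key=subkey)  ## ~half the entries zeroed, survivors scaled by 2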
