
Commit 27fd9bf

Author: Alexander Ororbia
Message: minor tweak to attn probe
Parent: dc8c127


ngclearn/utils/analysis/attentive_probe.py

Lines changed: 7 additions & 7 deletions
@@ -74,7 +74,7 @@ def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array
     return attention @ Wout + bout # (B, T, Dq)
 
 @bind(jax.jit, static_argnums=[3, 4, 5, 6, 7])
-def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=True, use_softmax=True):
+def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False, use_softmax=True):
     """
     Runs full nonlinear attentive probe on input encodings (typically embedding vectors produced by some other model).
 
@@ -138,8 +138,8 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
         outs = jax.nn.softmax(outs)
     return outs, features
 
-@bind(jax.jit, static_argnums=[4, 5, 6, 7])
-def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_softmax=True):
+@bind(jax.jit, static_argnums=[4, 5, 6, 7, 8])
+def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False, use_softmax=True):
     """
     Runs and evaluates the nonlinear attentive probe given a paired set of encoding vectors and externally assigned
     labels/regression targets.
@@ -165,7 +165,7 @@ def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout:
         current loss value, output scores/probabilities
     """
     # encodings: (B, hw, dim)
-    outs, _ = run_attention_probe(params, encodings, mask, n_heads, dropout, use_LN, use_softmax)
+    outs, _ = run_attention_probe(params, encodings, mask, n_heads, dropout, use_LN, use_LN_input, use_softmax)
     if use_softmax: ## Multinoulli log likelihood for 1-of-K predictions
         L = -jnp.mean(jnp.sum(jnp.log(outs.clip(min=1e-5)) * labels, axis=1, keepdims=True))
     else: ## MSE for real-valued outputs
@@ -206,7 +206,7 @@ class AttentiveProbe(Probe):
     """
     def __init__(
        self, dkey, source_seq_length, input_dim, out_dim, num_heads=8, attn_dim=64,
-       target_seq_length=1, learnable_query_dim=32, batch_size=1, hid_dim=32, use_LN=True, use_LN_input=True, use_softmax=True, **kwargs
+       target_seq_length=1, learnable_query_dim=32, batch_size=1, hid_dim=32, use_LN=True, use_LN_input=False, use_softmax=True, **kwargs
     ):
        super().__init__(dkey, batch_size, **kwargs)
        assert attn_dim % num_heads == 0, f"`attn_dim` must be divisible by `num_heads`. Got {attn_dim} and {num_heads}."
@@ -288,7 +288,7 @@ def process(self, embedding_sequence):
        #print(embedding_sequence.shape)
        outs, feats = run_attention_probe(
            self.probe_params, embedding_sequence, self.dev_mask, self.num_heads, 0.0, use_LN=self.use_LN,
-           use_softmax=self.use_softmax
+           use_LN_input=self.use_LN_input, use_softmax=self.use_softmax
        )
        return outs
 
@@ -310,7 +310,7 @@ def update(self, embedding_sequence, labels, dkey=None):
        # NOTE: Viet: Change back to 0.0 for now for the code to run
        outputs, grads = self.grad_fx(
            self.probe_params, embedding_sequence, labels, self.mask, self.num_heads, dropout=0.0, use_LN=self.use_LN,
-           use_softmax=self.use_softmax
+           use_LN_input=self.use_LN_input, use_softmax=self.use_softmax
        )
        loss, predictions = outputs
        ## adjust parameters of probe
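
For context, this tweak threads the `use_LN_input` flag (now defaulting to `False`, i.e. no input-side layer normalization) through both jitted probe functions and the `AttentiveProbe` wrapper. Below is a minimal usage sketch assuming only the constructor and method signatures visible in this diff; the sizes, random keys, and dummy data are illustrative and not taken from the repository.

import jax
import jax.numpy as jnp
from ngclearn.utils.analysis.attentive_probe import AttentiveProbe

# Illustrative sizes (not from the commit): batch, source sequence length, embedding dim.
B, T, D = 8, 16, 128
num_classes = 10

dkey = jax.random.PRNGKey(42)

# use_LN_input defaults to False after this commit, so input-side layer
# normalization stays off unless explicitly requested.
probe = AttentiveProbe(
    dkey, source_seq_length=T, input_dim=D, out_dim=num_classes,
    num_heads=8, attn_dim=64, batch_size=B,
    use_LN=True, use_LN_input=False, use_softmax=True
)

# Dummy embedding sequence and one-hot labels, purely for illustration.
embeddings = jax.random.normal(jax.random.PRNGKey(1), (B, T, D))
labels = jax.nn.one_hot(jnp.arange(B) % num_classes, num_classes)

scores = probe.process(embeddings)   # probe output scores/probabilities
probe.update(embeddings, labels)     # one gradient-based adjustment of the probe parameters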
