Skip to content

Commit 9ad4ae2

Browse files
author
Alexander Ororbia
committed
cleaned up probe parent fit routine
1 parent b688c6c commit 9ad4ae2

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

ngclearn/utils/analysis/linear_probe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(
8484
self.use_LN = use_LN
8585
self.l2_decay = 0.0001
8686
self.l1_decay = 0.000025
87+
# eta = 0.05 for SGD, batch_size=2000
8788

8889
## set up classifier
8990
flat_input_dim = input_dim * source_seq_length

ngclearn/utils/analysis/probe.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@ class Probe():
1111
1212
"""
1313
def __init__(
14-
self, dkey, batch_size=4, **kwargs
14+
self, dkey, batch_size=1, dev_batch_size=1, **kwargs
1515
):
1616
#dkey, *subkeys = random.split(dkey, 3)
1717
self.dkey = dkey
1818
self.batch_size = batch_size
19+
self.dev_batch_size = dev_batch_size
1920

2021
def process(self, embeddings):
2122
predictions = None
@@ -25,24 +26,29 @@ def update(self, embeddings, labels):
2526
L = predictions = None
2627
return L, predictions
2728

28-
def predict(self, data):
29+
def predict(self, data, batch_size=None):
2930
"""
3031
Runs this probe's inference scheme over a pool of data.
3132
3233
Args:
3334
data: a dataset or design tensor/matrix containing encoding vector sequences; shape (N, T, embed_dim) or (N, embed_dim)
3435
36+
batch_size: optional number of samples per mini-batch used when slicing `data` for inference (Default: None, in which case this probe's training batch size, `self.batch_size`, is used)
37+
3538
Returns:
3639
the output scores/predictions made by this probe
3740
"""
41+
_batch_size = batch_size
42+
if _batch_size is None:
43+
_batch_size = self.batch_size
3844
_data = data
3945
if len(_data.shape) < 3:
4046
_data = jnp.expand_dims(_data, axis=1)
4147

4248
n_samples, seq_len, dim = _data.shape
43-
n_batches = int(n_samples / self.batch_size)
49+
n_batches = int(n_samples / _batch_size)
4450
s_ptr = 0
45-
e_ptr = self.batch_size
51+
e_ptr = _batch_size
4652
Y_mu = []
4753
for b in range(n_batches):
4854
x_mb = _data[s_ptr:e_ptr, :, :] ## slice out 3D batch tensor
@@ -132,7 +138,7 @@ def fit(self, dataset, dev_dataset=None, n_iter=50, patience=20):
132138

133139
impatience += 1
134140
if dev_data is not None:
135-
Ymu = self.predict(dev_data)
141+
Ymu = self.predict(dev_data, batch_size=self.dev_batch_size)
136142
acc = jnp.sum(jnp.equal(jnp.argmax(Ymu, axis=1), jnp.argmax(dev_labels, axis=1))) / (dev_labels.shape[0] * 1.)
137143
if acc > best_acc:
138144
best_acc = acc

0 commit comments

Comments
 (0)