Skip to content

Commit 2feeced

Browse files
author
Alexander Ororbia
committed
cleaned up probes
1 parent 84005b5 commit 2feeced

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

ngclearn/utils/analysis/probe.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def predict(self, data, batch_size=None):
7979
x_mb = _data[s_ptr:e_ptr, :, :] ## slice out 3D batch tensor
8080
s_ptr = e_ptr
8181
e_ptr += x_mb.shape[0]
82-
y_mu = self.process(x_mb)
82+
y_mu = self.process(x_mb, dkey=None)
8383
Y_mu.append(y_mu)
8484
Y_mu = jnp.concatenate(Y_mu, axis=0)
8585
return Y_mu
@@ -143,8 +143,9 @@ def fit(self, dataset, dev_dataset=None, n_iter=50, patience=20):
143143
s_ptr = e_ptr
144144
e_ptr += x_mb.shape[0]
145145
Ns += x_mb.shape[0]
146+
self.dkey, *subkeys = random.split(self.dkey, 2)
146147

147-
_L, py = self.update(x_mb, y_mb)
148+
_L, py = self.update(x_mb, y_mb, dkey=subkeys[0])
148149
acc = jnp.sum(jnp.equal(jnp.argmax(py, axis=1), jnp.argmax(y_mb, axis=1))) + acc
149150
L = (_L * x_mb.shape[0]) + L ## we remove the batch division from loss w.r.t. x_mb/y_mb
150151

0 commit comments

Comments
 (0)