Skip to content

Commit 155d830

Browse files
author
Alexander Ororbia
committed
cleaned up probe parent fit routine
1 parent 3a2de99 commit 155d830

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

ngclearn/utils/analysis/probe.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,12 @@ def fit(self, dataset, dev_dataset=None, n_iter=50, patience=20):
122122
_L, py = self.update(x_mb, y_mb)
123123
acc = jnp.sum(jnp.equal(jnp.argmax(py, axis=1), jnp.argmax(y_mb, axis=1))) + acc
124124
L = (_L * x_mb.shape[0]) + L ## we remove the batch division from loss w.r.t. x_mb/y_mb
125-
if dev_data is not None:
126-
print(f"\r{ii} L = {L / Ns:.3f} Acc = {acc / Ns:.2f} Dev.Acc = {best_acc:.2f}", end="")
127-
else:
128-
print(f"\r{ii} L = {L / Ns:.3f} Acc = {acc / Ns:.2f}", end="")
129-
print()
125+
126+
if dev_data is not None:
127+
print(f"\r{ii} L = {L / Ns:.3f} Acc = {acc / Ns:.2f} Dev.Acc = {best_acc:.2f}", end="")
128+
else:
129+
print(f"\r{ii} L = {L / Ns:.3f} Acc = {acc / Ns:.2f}", end="")
130+
130131
acc = acc / Ns
131132
L = L / Ns ## compute current loss over (train) dataset
132133

@@ -144,5 +145,6 @@ def fit(self, dataset, dev_dataset=None, n_iter=50, patience=20):
144145

145146
if impatience > patience:
146147
break ## execute early stopping
148+
print()
147149
return best_acc
148150

0 commit comments

Comments
 (0)