Skip to content

Commit 3a2de99

Browse files
author
Alexander Ororbia
committed
cleaned up probe parent fit routine
1 parent 9ad4ae2 commit 3a2de99

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

ngclearn/utils/analysis/probe.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def fit(self, dataset, dev_dataset=None, n_iter=50, patience=20):
7575
patience: number of iterations of improvement (decrease) in loss before early-stopping enacted
7676
7777
Returns:
78-
the output scores/predictions made by this probe
78+
best accuracy found over fitting run
7979
"""
8080
data, labels = dataset
8181
dev_data = dev_labels = None
@@ -97,9 +97,7 @@ def fit(self, dataset, dev_dataset=None, n_iter=50, patience=20):
9797

9898
## run main probe fitting loop
9999
impatience = 0
100-
final_L = 10000.
101100
best_acc = 0.
102-
#Y_mu = []
103101
_Y = None
104102
for ii in range(n_iter):
105103
## shuffle data (to ensure i.i.d. across sequences)
@@ -128,13 +126,9 @@ def fit(self, dataset, dev_dataset=None, n_iter=50, patience=20):
128126
print(f"\r{ii} L = {L / Ns:.3f} Acc = {acc / Ns:.2f} Dev.Acc = {best_acc:.2f}", end="")
129127
else:
130128
print(f"\r{ii} L = {L / Ns:.3f} Acc = {acc / Ns:.2f}", end="")
131-
# if ii == ii-1:
132-
# Y_mu.append(py)
133129
print()
134130
acc = acc / Ns
135-
final_L = L / Ns ## compute current loss over (train) dataset
136-
# if ii == ii - 1:
137-
# Y_mu = jnp.concatenate(Y_mu, axis=0)
131+
L = L / Ns ## compute current loss over (train) dataset
138132

139133
impatience += 1
140134
if dev_data is not None:
@@ -150,5 +144,5 @@ def fit(self, dataset, dev_dataset=None, n_iter=50, patience=20):
150144

151145
if impatience > patience:
152146
break ## execute early stopping
153-
return final_L
147+
return best_acc
154148

0 commit comments

Comments
 (0)