@@ -75,7 +75,7 @@ def fit(self, dataset, dev_dataset=None, n_iter=50, patience=20):
             patience: number of iterations without improvement (decrease) in loss before early stopping is enacted

         Returns:
-            the output scores/predictions made by this probe
+            best accuracy found over the fitting run
         """
         data, labels = dataset
         dev_data = dev_labels = None
@@ -97,9 +97,7 @@ def fit(self, dataset, dev_dataset=None, n_iter=50, patience=20):

         ## run main probe fitting loop
         impatience = 0
-        final_L = 10000.
         best_acc = 0.
-        #Y_mu = []
         _Y = None
         for ii in range(n_iter):
             ## shuffle data (to ensure i.i.d. across sequences)
@@ -128,13 +126,9 @@ def fit(self, dataset, dev_dataset=None, n_iter=50, patience=20):
                 print(f"\r{ii} L = {L/Ns:.3f} Acc = {acc/Ns:.2f} Dev.Acc = {best_acc:.2f}", end="")
             else:
                 print(f"\r{ii} L = {L/Ns:.3f} Acc = {acc/Ns:.2f}", end="")
-            # if ii == ii-1:
-            #     Y_mu.append(py)
             print()
             acc = acc / Ns
-            final_L = L / Ns ## compute current loss over (train) dataset
-            # if ii == ii - 1:
-            #     Y_mu = jnp.concatenate(Y_mu, axis=0)
+            L = L / Ns ## compute current loss over (train) dataset

             impatience += 1
             if dev_data is not None:
@@ -150,5 +144,5 @@ def fit(self, dataset, dev_dataset=None, n_iter=50, patience=20):

             if impatience > patience:
                 break ## execute early stopping
-        return final_L
+        return best_acc

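For context, a minimal usage sketch of the updated fit() is below. The Probe class name, its constructor, and the placeholder arrays are hypothetical; only fit()'s signature (dataset, dev_dataset, n_iter, patience) and its new return value (the best accuracy found during the run) follow from the diff above.

import jax.numpy as jnp

# Hypothetical setup: `Probe` stands in for the class that owns fit();
# the array shapes below are purely illustrative.
probe = Probe()

# fit() unpacks `dataset` as a (data, labels) pair.
train_set = (jnp.zeros((1000, 768)), jnp.zeros((1000,), dtype=jnp.int32))
dev_set   = (jnp.zeros((200, 768)),  jnp.zeros((200,),  dtype=jnp.int32))

# Runs up to n_iter iterations, early-stopping after `patience` iterations
# without improvement; with this change it returns the best accuracy seen
# during the run rather than the final training loss.
best_acc = probe.fit(train_set, dev_dataset=dev_set, n_iter=50, patience=20)
print(f"best accuracy over fitting run: {best_acc:.2f}")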