|
18 | 18 | from .utils import split |
19 | 19 |
|
20 | 20 |
|
def measure_error(classifier_j, classifier_k, labeled_data):
    """Return the ratio of disagreements to agreements between the two
    classifiers' predictions on ``labeled_data``.

    Note: this is disagreements / agreements (not disagreements / total),
    so the value can exceed 1.0.

    :param classifier_j: first fitted classifier (must support ``predict``)
    :param classifier_k: second fitted classifier (must support ``predict``)
    :param labeled_data: samples both classifiers predict on
    :return: (n_total - n_agree) / n_agree
    """
    pred_j = classifier_j.predict(labeled_data)
    pred_k = classifier_k.predict(labeled_data)
    # NOTE(review): if the classifiers agree on no samples, same == 0 and
    # the division below raises ZeroDivisionError.
    same = len([0 for x, y in zip(pred_j, pred_k) if x == y])
    return (len(pred_j) - same) / same
28 | 21 | class TriTraining: |
29 | 22 | """ |
30 | 23 | Zhou, Z. H., & Li, M. (2005). Tri-training: Exploiting unlabeled data |
@@ -203,7 +196,8 @@ def fit(self, samples, y): |
203 | 196 | ): |
204 | 197 | break |
205 | 198 |
|
206 | | - def _check_for_update(self, e_j, ep_j, h_j, l_j, labeled, lp_j, update_j, y): |
| 199 | + def _check_for_update(self, e_j, ep_j, h_j, l_j, |
| 200 | + labeled, lp_j, update_j, y): |
207 | 201 | """ |
208 | 202 | If the update_j flag is True, then we concatenate the labeled data with |
209 | 203 | the new data, and fit the model to the new data |
@@ -244,7 +238,7 @@ def _train_classifier(self, ep_k, h_i, h_j, h_k, labeled, lp_k, u): |
244 | 238 | """ |
245 | 239 | update_k = False |
246 | 240 | l_k = Bunch(data=np.array([]), target=np.array([])) |
247 | | - e_k = measure_error(h_j, h_k, labeled) |
| 241 | + e_k = self.measure_error(h_j, h_k, labeled) |
248 | 242 | if e_k < ep_k: |
249 | 243 | for sample in u: |
250 | 244 | sample_s = sample.reshape(1, -1) |
@@ -286,3 +280,20 @@ def predict(self, samples): |
286 | 280 | labels.append(np.where(count == np.amax(count))[0][0]) |
287 | 281 |
|
288 | 282 | return np.array(labels) |
| 283 | + |
| 284 | + @staticmethod |
| 285 | + def measure_error(classifier_j, classifier_k, labeled_data): |
| 286 | + """ |
| 287 | + It returns the fraction of the time that classifiers j and k disagree on |
| 288 | + the labels of the labeled data |
| 289 | +
|
| 290 | + :param classifier_j: the classifier you want to compare to |
| 291 | + :param classifier_k: the classifier that we want to measure the error of |
| 292 | + :param labeled_data: the labeled data that we're using to train the |
| 293 | + classifiers |
| 294 | + :return: The error rate of the two classifiers. |
| 295 | + """ |
| 296 | + pred_j = classifier_j.predict(labeled_data) |
| 297 | + pred_k = classifier_k.predict(labeled_data) |
| 298 | + same = len([0 for x, y in zip(pred_j, pred_k) if x == y]) |
| 299 | + return (len(pred_j) - same) / same |
0 commit comments