Commit f2a7b15

Minor refactors #207
1 parent cc2e1df commit f2a7b15

2 files changed: +25 -14 lines


instance_selection/_ENN.py

Lines changed: 5 additions & 5 deletions
@@ -47,11 +47,11 @@ def __init__(self, nearest_neighbors=3, power_parameter=2):
 
     def _neighs(self, s_samples, s_targets, index, removed):
         """
-        _neighs() takes in the samples and targets, the index of the sample to
-        be removed, and the number of samples already removed. It returns the
-        sample to be removed, its target, the targets of the samples not yet
-        removed, the samples not yet removed, and the indices of the nearest
-        neighbors of the sample to be removed.
+        The function takes in the samples and targets, the index of the
+        sample to be removed, and the number of samples already removed. It
+        returns the sample to be removed, its target, the targets of the
+        samples not yet removed, the samples not yet removed, and the
+        indices of the nearest neighbors of the sample to be removed.
 
         :param s_samples: the samples that are being used to train the model
         :param s_targets: the targets of the samples
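
For context on the reworded docstring, here is a minimal sketch of a helper with the contract it describes, assuming numpy arrays and scikit-learn's NearestNeighbors; _neighs_sketch is a hypothetical stand-in, and k and p mirror the nearest_neighbors=3 and power_parameter=2 defaults visible in the hunk header. This illustrates the documented behaviour, not the actual _ENN.py implementation.

import numpy as np
from sklearn.neighbors import NearestNeighbors

def _neighs_sketch(s_samples, s_targets, index, removed, k=3, p=2):
    # Shift the requested index by the number of samples already removed.
    pos = index - removed
    # The sample to be removed and its target.
    x_sample = s_samples[pos]
    x_target = s_targets[pos]
    # The samples and targets not yet removed.
    samples_not_x = np.delete(s_samples, pos, axis=0)
    targets_not_x = np.delete(s_targets, pos, axis=0)
    # Indices of the nearest neighbours of the sample among the remaining
    # ones, using k neighbours and Minkowski power parameter p.
    knn = NearestNeighbors(n_neighbors=k, p=p)
    knn.fit(samples_not_x)
    _, neigh_ind = knn.kneighbors(x_sample.reshape(1, -1))
    return x_sample, x_target, targets_not_x, samples_not_x, neigh_ind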

semisupervised/TriTraining.py

Lines changed: 20 additions & 9 deletions
@@ -18,13 +18,6 @@
 from .utils import split
 
 
-def measure_error(classifier_j, classifier_k, labeled_data):
-    pred_j = classifier_j.predict(labeled_data)
-    pred_k = classifier_k.predict(labeled_data)
-    same = len([0 for x, y in zip(pred_j, pred_k) if x == y])
-    return (len(pred_j) - same) / same
-
-
 class TriTraining:
     """
     Zhou, Z. H., & Li, M. (2005). Tri-training: Exploiting unlabeled data
@@ -203,7 +196,8 @@ def fit(self, samples, y):
            ):
                break
 
-    def _check_for_update(self, e_j, ep_j, h_j, l_j, labeled, lp_j, update_j, y):
+    def _check_for_update(self, e_j, ep_j, h_j, l_j,
+                          labeled, lp_j, update_j, y):
         """
         If the update_j flag is True, then we concatenate the labeled data with
         the new data, and fit the model to the new data
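
The docstring above describes a concatenate-and-refit step. A minimal sketch of that pattern, assuming the Bunch(data=..., target=...) layout used elsewhere in this diff; _apply_update is a hypothetical helper name, and the real method also maintains the e_j, ep_j, and lp_j bookkeeping omitted here.

import numpy as np

def _apply_update(update_j, h_j, labeled, l_j):
    # When the update flag is set, merge the labeled pool with the
    # newly labeled batch and refit classifier j on the result.
    if update_j:
        data = np.concatenate((labeled.data, l_j.data))
        target = np.concatenate((labeled.target, l_j.target))
        h_j.fit(data, target)
    return h_j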
@@ -244,7 +238,7 @@ def _train_classifier(self, ep_k, h_i, h_j, h_k, labeled, lp_k, u):
         """
         update_k = False
         l_k = Bunch(data=np.array([]), target=np.array([]))
-        e_k = measure_error(h_j, h_k, labeled)
+        e_k = self.measure_error(h_j, h_k, labeled)
         if e_k < ep_k:
             for sample in u:
                 sample_s = sample.reshape(1, -1)
@@ -286,3 +280,20 @@ def predict(self, samples):
             labels.append(np.where(count == np.amax(count))[0][0])
 
         return np.array(labels)
+
+    @staticmethod
+    def measure_error(classifier_j, classifier_k, labeled_data):
+        """
+        It returns the fraction of the time that classifiers j and k disagree on
+        the labels of the labeled data
+
+        :param classifier_j: the classifier you want to compare to
+        :param classifier_k: the classifier that we want to measure the error of
+        :param labeled_data: the labeled data that we're using to train the
+            classifiers
+        :return: The error rate of the two classifiers.
+        """
+        pred_j = classifier_j.predict(labeled_data)
+        pred_k = classifier_k.predict(labeled_data)
+        same = len([0 for x, y in zip(pred_j, pred_k) if x == y])
+        return (len(pred_j) - same) / same
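
A quick check of what the relocated measure_error computes: same counts the predictions on which the two classifiers agree, so the return value is the number of disagreements divided by the number of agreements (a ratio rather than a fraction of all predictions, and undefined when the classifiers never agree). A self-contained usage sketch, assuming the package is importable as semisupervised; _Stub is an illustrative stand-in exposing only the predict() interface the method relies on.

import numpy as np
from semisupervised.TriTraining import TriTraining

class _Stub:
    # Minimal stand-in with the predict() interface the method expects.
    def __init__(self, labels):
        self.labels = np.asarray(labels)

    def predict(self, data):
        return self.labels

j = _Stub([0, 0, 1, 1])  # predictions from classifier j
k = _Stub([0, 1, 1, 1])  # classifier k disagrees on one of four samples
data = np.zeros((4, 2))  # placeholder features; the stubs ignore them

# One disagreement over three agreements: prints 0.3333...
print(TriTraining.measure_error(j, k, data))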
