Commit a6e6eb6

Docs #194
1 parent f101b76 commit a6e6eb6

3 files changed: +34 −37 lines

semisupervised/DemocraticCoLearning.py

Lines changed: 26 additions & 25 deletions

@@ -6,6 +6,7 @@
 # @Version: 5.0

 import copy
+import warnings
 from math import sqrt

 import numpy as np
@@ -17,21 +18,6 @@
 from .utils import split


-def check_bounds(wi):
-    """
-    It checks that the lower bound is not less than 0 and the upper bound is not
-    greater than 1
-
-    :param wi: lower and upper mean confidence
-    :return: the fixed wi.
-    """
-    if wi[0] < 0:
-        wi[0] = 0
-    if wi[1] > 1:
-        wi[1] = 1
-    return wi
-
-
 class DemocraticCoLearning:
     """
     Democratic Co-Learning Implementation. Based on:
@@ -168,7 +154,7 @@ def fit(self, samples, y):
                                               len(labeled)),
                   error + self.const * sqrt((error * (1 - error)) /
                                             len(labeled))]
-            w1 = sum(check_bounds(w1)) / 2
+            w1 = sum(self.check_bounds(w1)) / 2

             for index, proba in enumerate(probas):
                 c_k = new_labels[index][0]
@@ -193,7 +179,7 @@ def fit(self, samples, y):
                                               len(labeled)),
                   error + self.const * sqrt((error * (1 - error)) /
                                             len(labeled))]
-            w2 = sum(check_bounds(w2)) / 2
+            w2 = sum(self.check_bounds(w2)) / 2

             for index, proba in enumerate(probas):
                 c_k = new_labels[index][0]
@@ -218,7 +204,7 @@ def fit(self, samples, y):
                                               len(labeled)),
                   error + self.const * sqrt((error * (1 - error)) /
                                             len(labeled))]
-            w3 = sum(check_bounds(w3)) / 2
+            w3 = sum(self.check_bounds(w3)) / 2

             for index, proba in enumerate(probas):
                 c_k = new_labels[index][0]
@@ -249,7 +235,7 @@ def fit(self, samples, y):
             ci_1 = [
                 error - self.const * sqrt((error * (1 - error)) / len(pred)),
                 error + self.const * sqrt((error * (1 - error)) / len(pred))]
-            ci_1 = check_bounds(ci_1)
+            ci_1 = self.check_bounds(ci_1)
             q_1 = len(pred) * pow((1 - 2 * (e_1 / len(pred))), 2)
             e_prime_1 = (1 - (ci_1[0] * len(pred)) / len(pred)) * len(pred)
             q_prime_1 = (len(l1_data) + len(pred)) * pow(
@@ -273,7 +259,7 @@ def fit(self, samples, y):
             ci_2 = [
                 error - self.const * sqrt((error * (1 - error)) / len(pred)),
                 error + self.const * sqrt((error * (1 - error)) / len(pred))]
-            ci_2 = check_bounds(ci_2)
+            ci_2 = self.check_bounds(ci_2)
             q_2 = len(pred) * pow((1 - 2 * (e_2 / len(pred))), 2)
             e_prime_2 = (1 - (ci_2[0] * len(pred)) / len(pred)) * len(pred)
             q_prime_2 = (len(l2_data) + len(pred)) * pow(
@@ -297,7 +283,7 @@ def fit(self, samples, y):
             ci_3 = [
                 error - self.const * sqrt((error * (1 - error)) / len(pred)),
                 error + self.const * sqrt((error * (1 - error)) / len(pred))]
-            ci_3 = check_bounds(ci_3)
+            ci_3 = self.check_bounds(ci_3)
             q_3 = len(pred) * pow((1 - 2 * (e_3 / len(pred))), 2)
             e_prime_3 = (1 - (ci_3[0] * len(pred)) / len(pred)) * len(pred)
             q_prime_3 = (len(l3_data) + len(pred)) * pow(
@@ -316,17 +302,17 @@ def fit(self, samples, y):
         error = len([0 for p, tar in zip(pred, y) if p != tar]) / len(pred)
         w1 = [error - self.const * sqrt((error * (1 - error)) / len(labeled)),
               error + self.const * sqrt((error * (1 - error)) / len(labeled))]
-        self.w1 = sum(check_bounds(w1)) / 2
+        self.w1 = sum(self.check_bounds(w1)) / 2
         pred = self.h2.predict(labeled)
         error = len([0 for p, tar in zip(pred, y) if p != tar]) / len(pred)
         w2 = [error - self.const * sqrt((error * (1 - error)) / len(labeled)),
               error + self.const * sqrt((error * (1 - error)) / len(labeled))]
-        self.w2 = sum(check_bounds(w2)) / 2
+        self.w2 = sum(self.check_bounds(w2)) / 2
         pred = self.h3.predict(labeled)
         error = len([0 for p, tar in zip(pred, y) if p != tar]) / len(pred)
         w3 = [error - self.const * sqrt((error * (1 - error)) / len(labeled)),
               error + self.const * sqrt((error * (1 - error)) / len(labeled))]
-        self.w3 = sum(check_bounds(w3)) / 2
+        self.w3 = sum(self.check_bounds(w3)) / 2

     def predict(self, samples):
         """
@@ -361,7 +347,7 @@ def predict(self, samples):
                     gj[p] += 1
                     gj_h[2][p] += 1
                 except IndexError:
-                    breakpoint()
+                    warnings.warn("Retraining the model is advised.")

         confidence = [0 for _ in range(self.n_labels)]
         for index, j in enumerate(gj):
@@ -386,3 +372,18 @@ def predict(self, samples):
             labels.append(np.where(count == np.amax(count))[0][0])

         return np.array(labels)
+
+    @staticmethod
+    def check_bounds(wi):
+        """
+        It checks that the lower bound is not less than 0 and the upper bound
+        is not greater than 1
+
+        :param wi: lower and upper mean confidence
+        :return: the fixed wi.
+        """
+        if wi[0] < 0:
+            wi[0] = 0
+        if wi[1] > 1:
+            wi[1] = 1
+        return wi
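
Note on the relocated helper: check_bounds only clips a two-element confidence interval into [0, 1] before it is averaged into a classifier weight, and as a @staticmethod it can be called without fitting anything. A minimal sketch of the new call pattern, assuming the package is importable from the repository root; the interval values are made up for illustration:

    from semisupervised.DemocraticCoLearning import DemocraticCoLearning

    # Illustrative interval that overshoots [0, 1] on both ends.
    ci = DemocraticCoLearning.check_bounds([-0.05, 1.10])  # -> [0, 1]
    w = sum(ci) / 2  # mean confidence, mirroring how fit() uses it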

semisupervised/TriTraining.py

Lines changed: 7 additions & 7 deletions

@@ -80,7 +80,7 @@ def __init__(self, random_state=None,
         self.random_state = random_state if random_state is not None else \
             np.random.randint(low=0, high=10e5, size=1)[0]

-    def subsample(self, l_t, s):
+    def _subsample(self, l_t, s):
         np.random.seed(self.random_state)
         rng = np.random.default_rng()
         data = np.array(l_t['data'])
@@ -144,8 +144,8 @@ def fit(self, samples, y):
                 if e_j * len(l_j['data']) < ep_j * lp_j:
                     update_j = True
                 elif lp_j > e_j / (ep_j - e_j):
-                    l_j = self.subsample(l_j, ceil(((ep_j * lp_j) / e_j)
-                                                   - 1))
+                    l_j = self._subsample(l_j, ceil(((ep_j * lp_j) / e_j)
+                                                    - 1))
                     update_j = True

             update_k = False
@@ -171,8 +171,8 @@ def fit(self, samples, y):
                 if e_k * len(l_k['data']) < ep_k * lp_k:
                     update_k = True
                 elif lp_k > e_k / (ep_k - e_k):
-                    l_k = self.subsample(l_k, ceil(((ep_k * lp_k) / e_k)
-                                                   - 1))
+                    l_k = self._subsample(l_k, ceil(((ep_k * lp_k) / e_k)
+                                                    - 1))
                     update_k = True

             update_i = False
@@ -198,8 +198,8 @@ def fit(self, samples, y):
                 if e_i * len(l_i['data']) < ep_i * lp_i:
                     update_i = True
                 elif lp_i > e_i / (ep_i - e_i):
-                    l_i = self.subsample(l_i, ceil(((ep_i * lp_i) / e_i)
-                                                   - 1))
+                    l_i = self._subsample(l_i, ceil(((ep_i * lp_i) / e_i)
+                                                    - 1))
                     update_i = True

             if update_j:
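
Note on the rename: the leading underscore only marks the helper as internal to TriTraining, so the in-class call sites above are the only ones the diff has to touch. A hypothetical external caller would migrate as sketched below, assuming the remaining constructor arguments have defaults; l_t is the labelled-pool dict and s the reduced size, as in the method above:

    tt = TriTraining(random_state=42)
    # before this commit: tt.subsample(l_t, s)
    # after this commit:  tt._subsample(l_t, s)  # private by convention; external code should not rely on it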

utils/__init__.py

Lines changed: 1 addition & 5 deletions

@@ -1,8 +1,4 @@
-#!/usr/bin/env python
-# -*- coding:utf-8 -*-
-# @Filename: __init__.py.py
-# @Author: Daniel Puente Ramírez
-# @Time: 22/12/21 18:05
+"""Utils ARFF"""

 from .arff2dataset import arff_data

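
Note on utils/__init__.py: only the module header changes; the re-export of arff_data stays in place, so a downstream import keeps resolving. Shown as a sketch only, since arff_data's signature is not part of this diff:

    from utils import arff_data  # still re-exported by utils/__init__.py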