66# @Version: 5.0
77
88import copy
9+ import warnings
910from math import sqrt
1011
1112import numpy as np
1718from .utils import split
1819
1920
20- def check_bounds (wi ):
21- """
22- It checks that the lower bound is not less than 0 and the upper bound is not
23- greater than 1
24-
25- :param wi: lower and upper mean confidence
26- :return: the fixed wi.
27- """
28- if wi [0 ] < 0 :
29- wi [0 ] = 0
30- if wi [1 ] > 1 :
31- wi [1 ] = 1
32- return wi
33-
34-
3521class DemocraticCoLearning :
3622 """
3723 Democratic Co-Learning Implementation. Based on:
@@ -168,7 +154,7 @@ def fit(self, samples, y):
168154 len (labeled )),
169155 error + self .const * sqrt ((error * (1 - error )) /
170156 len (labeled ))]
171- w1 = sum (check_bounds (w1 )) / 2
157+ w1 = sum (self . check_bounds (w1 )) / 2
172158
173159 for index , proba in enumerate (probas ):
174160 c_k = new_labels [index ][0 ]
@@ -193,7 +179,7 @@ def fit(self, samples, y):
193179 len (labeled )),
194180 error + self .const * sqrt ((error * (1 - error )) /
195181 len (labeled ))]
196- w2 = sum (check_bounds (w2 )) / 2
182+ w2 = sum (self . check_bounds (w2 )) / 2
197183
198184 for index , proba in enumerate (probas ):
199185 c_k = new_labels [index ][0 ]
@@ -218,7 +204,7 @@ def fit(self, samples, y):
218204 len (labeled )),
219205 error + self .const * sqrt ((error * (1 - error )) /
220206 len (labeled ))]
221- w3 = sum (check_bounds (w3 )) / 2
207+ w3 = sum (self . check_bounds (w3 )) / 2
222208
223209 for index , proba in enumerate (probas ):
224210 c_k = new_labels [index ][0 ]
@@ -249,7 +235,7 @@ def fit(self, samples, y):
249235 ci_1 = [
250236 error - self .const * sqrt ((error * (1 - error )) / len (pred )),
251237 error + self .const * sqrt ((error * (1 - error )) / len (pred ))]
252- ci_1 = check_bounds (ci_1 )
238+ ci_1 = self . check_bounds (ci_1 )
253239 q_1 = len (pred ) * pow ((1 - 2 * (e_1 / len (pred ))), 2 )
254240 e_prime_1 = (1 - (ci_1 [0 ] * len (pred )) / len (pred )) * len (pred )
255241 q_prime_1 = (len (l1_data ) + len (pred )) * pow (
@@ -273,7 +259,7 @@ def fit(self, samples, y):
273259 ci_2 = [
274260 error - self .const * sqrt ((error * (1 - error )) / len (pred )),
275261 error + self .const * sqrt ((error * (1 - error )) / len (pred ))]
276- ci_2 = check_bounds (ci_2 )
262+ ci_2 = self . check_bounds (ci_2 )
277263 q_2 = len (pred ) * pow ((1 - 2 * (e_2 / len (pred ))), 2 )
278264 e_prime_2 = (1 - (ci_2 [0 ] * len (pred )) / len (pred )) * len (pred )
279265 q_prime_2 = (len (l2_data ) + len (pred )) * pow (
@@ -297,7 +283,7 @@ def fit(self, samples, y):
297283 ci_3 = [
298284 error - self .const * sqrt ((error * (1 - error )) / len (pred )),
299285 error + self .const * sqrt ((error * (1 - error )) / len (pred ))]
300- ci_3 = check_bounds (ci_3 )
286+ ci_3 = self . check_bounds (ci_3 )
301287 q_3 = len (pred ) * pow ((1 - 2 * (e_3 / len (pred ))), 2 )
302288 e_prime_3 = (1 - (ci_3 [0 ] * len (pred )) / len (pred )) * len (pred )
303289 q_prime_3 = (len (l3_data ) + len (pred )) * pow (
@@ -316,17 +302,17 @@ def fit(self, samples, y):
316302 error = len ([0 for p , tar in zip (pred , y ) if p != tar ]) / len (pred )
317303 w1 = [error - self .const * sqrt ((error * (1 - error )) / len (labeled )),
318304 error + self .const * sqrt ((error * (1 - error )) / len (labeled ))]
319- self .w1 = sum (check_bounds (w1 )) / 2
305+ self .w1 = sum (self . check_bounds (w1 )) / 2
320306 pred = self .h2 .predict (labeled )
321307 error = len ([0 for p , tar in zip (pred , y ) if p != tar ]) / len (pred )
322308 w2 = [error - self .const * sqrt ((error * (1 - error )) / len (labeled )),
323309 error + self .const * sqrt ((error * (1 - error )) / len (labeled ))]
324- self .w2 = sum (check_bounds (w2 )) / 2
310+ self .w2 = sum (self . check_bounds (w2 )) / 2
325311 pred = self .h3 .predict (labeled )
326312 error = len ([0 for p , tar in zip (pred , y ) if p != tar ]) / len (pred )
327313 w3 = [error - self .const * sqrt ((error * (1 - error )) / len (labeled )),
328314 error + self .const * sqrt ((error * (1 - error )) / len (labeled ))]
329- self .w3 = sum (check_bounds (w3 )) / 2
315+ self .w3 = sum (self . check_bounds (w3 )) / 2
330316
331317 def predict (self , samples ):
332318 """
@@ -361,7 +347,7 @@ def predict(self, samples):
361347 gj [p ] += 1
362348 gj_h [2 ][p ] += 1
363349 except IndexError :
364- breakpoint ( )
350+ warnings . warn ( "Retraining the model is advised." )
365351
366352 confidence = [0 for _ in range (self .n_labels )]
367353 for index , j in enumerate (gj ):
@@ -386,3 +372,18 @@ def predict(self, samples):
386372 labels .append (np .where (count == np .amax (count ))[0 ][0 ])
387373
388374 return np .array (labels )
375+
376+ @staticmethod
377+ def check_bounds (wi ):
378+ """
379+ It checks that the lower bound is not less than 0 and the upper bound
380+ is not greater than 1
381+
382+ :param wi: lower and upper mean confidence
383+ :return: the fixed wi.
384+ """
385+ if wi [0 ] < 0 :
386+ wi [0 ] = 0
387+ if wi [1 ] > 1 :
388+ wi [1 ] = 1
389+ return wi
0 commit comments