Skip to content

Commit 364f0b0

Browse files
Updated SemiSupervised Library to v2.0 #133
1 parent 77b6b72 commit 364f0b0

File tree

6 files changed

+214
-161
lines changed

6 files changed

+214
-161
lines changed

semisupervised/CoTraining.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,19 @@
33
# @Filename: CoTraining.py
44
# @Author: Daniel Puente Ramírez
55
# @Time: 22/12/21 09:27
6-
# @Version: 2.0
6+
# @Version: 3.0
77

88
from math import ceil, floor
99

1010
import numpy as np
1111
from sklearn.naive_bayes import GaussianNB
1212
from sklearn.preprocessing import LabelEncoder
13+
from .utils import split
1314

1415

1516
class CoTraining:
1617

17-
def __init__(self, p=1, n=3, k=30, u=75, random_state=42):
18+
def __init__(self, p=1, n=3, k=30, u=75, random_state=None):
1819
self.p = p
1920
self.n = n
2021
self.k = k
@@ -24,29 +25,31 @@ def __init__(self, p=1, n=3, k=30, u=75, random_state=42):
2425
self.h1 = GaussianNB()
2526
self.h2 = GaussianNB()
2627

27-
def fit(self, L, U, y):
28-
if len(L) != len(y):
28+
def fit(self, samples, y):
29+
labeled, u, y = split(samples, y)
30+
31+
if len(labeled) != len(y):
2932
raise ValueError(
3033
f'The dimension of the labeled data must be the same as the '
31-
f'number of labels given. {len(L)} != {len(y)}'
34+
f'number of labels given. {len(labeled)} != {len(y)}'
3235
)
3336

3437
le = LabelEncoder()
3538
le.fit(y)
3639
y = le.transform(y)
3740
tot = self.n + self.p
3841

39-
self.size_x1 = ceil(len(L[0]) / 2)
42+
self.size_x1 = ceil(len(labeled[0]) / 2)
4043

4144
rng = np.random.default_rng()
42-
u_random_index = rng.choice(len(U), size=floor(self.u),
45+
u_random_index = rng.choice(len(u), size=floor(self.u),
4346
replace=False, shuffle=False)
4447

45-
u_prime = U[u_random_index]
48+
u_prime = u[u_random_index]
4649
u1, u2 = np.array_split(u_prime, 2, axis=1)
4750

4851
for _ in range(self.k):
49-
x1, x2 = np.array_split(L, 2, axis=1)
52+
x1, x2 = np.array_split(labeled, 2, axis=1)
5053

5154
self.h1.fit(x1, y)
5255
self.h2.fit(x2, y)
@@ -74,29 +77,29 @@ def fit(self, L, U, y):
7477
u1_new_samples = np.concatenate((u1_samples, u2_x1_samples), axis=1)
7578
u2_new_samples = np.concatenate((u2_samples, u1_x2_samples), axis=1)
7679
u_new = np.concatenate((u1_new_samples, u2_new_samples))
77-
L = np.concatenate((L, u_new))
80+
labeled = np.concatenate((labeled, u_new))
7881
y_new = np.array([x[0] for x in top_h1] + [x[0] for x in top_h2])
7982
y = np.concatenate((y, y_new))
8083

81-
old_indexes = np.array([x[2] for x in top_h1] + [x[2] for x in \
84+
old_indexes = np.array([x[2] for x in top_h1] + [x[2] for x in
8285
top_h2], int)
8386
u_prime = np.delete(u_prime, old_indexes, axis=0)
8487

85-
U = np.delete(U, u_random_index, axis=0)
88+
u = np.delete(u, u_random_index, axis=0)
8689
try:
87-
u_random_index = rng.choice(len(U),
90+
u_random_index = rng.choice(len(u),
8891
size=2 * self.p + 2 * self.n,
8992
replace=False, shuffle=False)
9093
except ValueError:
9194
print(f'The model was incorrectly parametrized, k is to big.')
9295
try:
93-
u_prime = np.concatenate((u_prime, U[u_random_index]))
96+
u_prime = np.concatenate((u_prime, u[u_random_index]))
9497
except IndexError:
9598
print('The model was incorrectly parametrized, there are not '
9699
'enough unlabeled samples.')
97100

98-
def predict(self, X):
99-
x1, x2 = np.array_split(X, 2, axis=1)
101+
def predict(self, samples):
102+
x1, x2 = np.array_split(samples, 2, axis=1)
100103
pred1, pred_proba1 = self.h1.predict(x1), self.h1.predict_proba(x1)
101104
pred2, pred_proba2 = self.h2.predict(x2), self.h2.predict_proba(x2)
102105
labels = []

0 commit comments

Comments
 (0)