66
77import random
88
9- import pytest
109import numpy as np
1110import pandas as pd
11+ import pytest
1212from sklearn .datasets import load_iris as load_digits
1313from sklearn .model_selection import train_test_split
14+ from sklearn .neighbors import KNeighborsClassifier
1415
16+ from instance_selection import ENN
1517from semisupervised import STDPNF , CoTraining , TriTraining , \
1618 DemocraticCoLearning
1719
1820
19- def to_dataframe (y ):
20- if not isinstance (y , pd .DataFrame ):
21- return pd .DataFrame (y )
22- return y
23-
24-
2521@pytest .fixture
2622def digits_dataset_ss ():
2723 x , y = load_digits (return_X_y = True , as_frame = True )
@@ -36,7 +32,7 @@ def digits_dataset_ss():
3632 y_train = pd .DataFrame (y_train )
3733 y_test = pd .DataFrame (y_test )
3834 li = list (set (range (x_train .shape [0 ])))
39- unlabeled = random .sample (li , int (x_train .shape [0 ] * 0.3 ))
35+ unlabeled = random .sample (li , int (x_train .shape [0 ] * 0.55 ))
4036 y_train .loc [unlabeled ] = - 1
4137
4238 return x_train , x_test , y_train , y_test , opt_labels
@@ -57,19 +53,68 @@ def base(x_train, x_test, y_train, y_test, opt_labels, algorithm, params=None):
5753
5854def test_co_training (digits_dataset_ss ):
5955 x_train , x_test , y_train , y_test , opt_labels = digits_dataset_ss
60- base (x_train , x_test , y_train , y_test , opt_labels , CoTraining )
56+ base (x_train , x_test , y_train , y_test , opt_labels , CoTraining ,
57+ {'p' : 1 , 'n' : 3 , 'k' : 1 , 'u' : 7 })
58+ base (x_train , x_test , y_train , y_test , opt_labels , CoTraining ,
59+ {'p' : 1 , 'n' : 3 , 'k' : 1 , 'u' : 7 ,
60+ 'c1' : KNeighborsClassifier , 'c1_params' : {'n_neighbors' : 3 },
61+ 'c2' : KNeighborsClassifier })
62+
63+ with pytest .raises (ValueError ):
64+ base (x_train , x_test , y_train , y_test , opt_labels , CoTraining )
65+
66+ with pytest .raises (ValueError ):
67+ base (x_train , x_test , y_train , y_test , opt_labels , CoTraining ,
68+ {'p' : 1 , 'n' : 3 , 'k' : 100 , 'u' : 7 })
69+
70+ with pytest .raises (ValueError ):
71+ base (x_train , x_test , y_train , y_test , opt_labels , CoTraining ,
72+ {'p' : 5 , 'n' : 5 , 'k' : 100 , 'u' : 15 })
6173
6274
6375def test_tri_training (digits_dataset_ss ):
6476 x_train , x_test , y_train , y_test , opt_labels = digits_dataset_ss
65- base (x_train , x_test , y_train , y_test , opt_labels , TriTraining )
77+ base (x_train , x_test , y_train , y_test , opt_labels , TriTraining ,
78+ {'c1' : KNeighborsClassifier , 'c1_params' : {'n_neighbors' : 3 },
79+ 'c2' : KNeighborsClassifier })
6680
6781
6882def test_demo_co_learning (digits_dataset_ss ):
6983 x_train , x_test , y_train , y_test , opt_labels = digits_dataset_ss
7084 base (x_train , x_test , y_train , y_test , opt_labels , DemocraticCoLearning )
85+ base (x_train , x_test , y_train , y_test , opt_labels , DemocraticCoLearning ,
86+ {'c1' : KNeighborsClassifier , 'c1_params' : {'n_neighbors' : 3 },
87+ 'c2' : KNeighborsClassifier })
7188
7289
7390def test_density_peaks (digits_dataset_ss ):
7491 x_train , x_test , y_train , y_test , opt_labels = digits_dataset_ss
7592 base (x_train , x_test , y_train , y_test , opt_labels , STDPNF )
93+
94+
95+ def test_density_peaks_filtering (digits_dataset_ss ):
96+ x_train , x_test , y_train , y_test , opt_labels = digits_dataset_ss
97+ with pytest .raises (AttributeError ):
98+ base (x_train , x_test , y_train , y_test , opt_labels , STDPNF ,
99+ {'filtering' : True })
100+ base (x_train , x_test , y_train , y_test , opt_labels , STDPNF ,
101+ {'filtering' : True , 'filter_method' : 'ENANE' })
102+
103+ base (x_train , x_test , y_train , y_test , opt_labels , STDPNF ,
104+ {'filtering' : True , 'filter_method' : ENN , 'dc' : 'auto' ,
105+ 'classifier' : KNeighborsClassifier })
106+
107+
108+ def test_different_len (digits_dataset_ss ):
109+ x , _ , y , _ , _ = digits_dataset_ss
110+ co = CoTraining ()
111+ tri = TriTraining ()
112+ demo_co = DemocraticCoLearning ()
113+ stdpnf = STDPNF ()
114+
115+ models = [co , tri , demo_co , stdpnf ]
116+ y = y [:- 1 ]
117+
118+ for model in models :
119+ with pytest .raises (ValueError ):
120+ model .fit (x , y )
0 commit comments