1212from sklearn .ensemble import RandomForestClassifier
1313from sklearn .metrics import accuracy_score
1414
15+
@pytest.fixture(scope="module")
def Xy():
    """Load the sample feature matrix and split off its ``label`` column.

    Returns:
        A ``(features, labels)`` pair: the DataFrame read from the sample CSV
        with the ``label`` column removed, and that removed column.

    NOTE(review): the CSV path is relative, so this presumably expects the
    tests to be launched from the repository root -- confirm.
    """
    frame = pd.read_csv('./tests/sample_data/sample_fm_enc.csv')
    # ``pop`` removes the column from ``frame`` in place and returns it.
    labels = frame.pop('label')
    return frame, labels
2021
22+
2123@pytest .fixture (scope = "module" )
2224def fit_dend (Xy ):
2325 X , y = Xy
@@ -43,25 +45,32 @@ def test_dend_fit(fit_dend):
4345 assert selector .edges is not None
4446 assert selector .graphs is not None
4547
48+
def test_dend_set_params(fit_dend):
    """Check that ``set_params`` can reset ``threshlist`` to ``None``.

    ``fit_dend`` is a module-scoped fixture shared with the other tests in
    this file, so the original ``threshlist`` is put back in a ``finally``
    clause: a failing assertion must not leave the shared selector mutated.
    """
    saved_threshlist = fit_dend.threshlist
    try:
        fit_dend.set_params(threshlist=None)
        assert fit_dend.threshlist is None
    finally:
        # Undo the mutation so later tests see the fitted state again.
        fit_dend.threshlist = saved_threshlist
5255
56+
def test_dend_features_at_step(fit_dend):
    """``features_at_step(48)`` on the fitted selector yields 79 entries."""
    selected = fit_dend.features_at_step(48)
    assert len(selected) == 79
5559
60+
def test_dend_find_set_of_size(fit_dend, capsys):
    """``find_set_of_size(80)`` on the fitted selector returns 6.

    NOTE(review): the 6 is presumably a step/threshold index -- confirm
    against ``Dendrogram.find_set_of_size``. ``capsys`` is requested but its
    captured output is not inspected here.
    """
    found = fit_dend.find_set_of_size(80)
    assert found == 6
58-
63+
64+
def test_dend_score_at_point(Xy, fit_dend):
    """Score a seeded random forest at point 2 and pin the resulting accuracy.

    The call is expected to return a pair whose first element is a sequence
    holding exactly one score; the second element is not inspected.
    """
    X, y = Xy
    scores, _ = fit_dend.score_at_point(X, y,
                                        RandomForestClassifier(random_state=0),
                                        accuracy_score, 2)
    assert len(scores) == 1
    # abs() makes this a two-sided tolerance check; without it any score
    # *below* the expected value would also pass.
    assert abs(scores[0] - .866666) < .00001
6472
73+
6574def test_dend_shuffle_all (fit_dend ):
6675 keys_1 = set (fit_dend .graphs [1 ].keys ())
6776 fit_dend .shuffle_all_representatives ()
@@ -72,42 +81,46 @@ def test_dend_shuffle_all(fit_dend):
def test_dend_shuffle_score_at_point(Xy, fit_dend):
    """``shuffle_score_at_point`` changes the graph keys and returns 2 scores.

    The key set of ``fit_dend.graphs[1]`` is snapshotted before the call and
    compared with the key set afterwards; the two must differ.
    """
    X, y = Xy
    # Snapshot must be taken before the call, which mutates ``fit_dend``.
    keys_before = set(fit_dend.graphs[1].keys())

    classifier = RandomForestClassifier()
    scores, _unused = fit_dend.shuffle_score_at_point(
        X, y, classifier, accuracy_score, 2, 2)

    keys_after = set(fit_dend.graphs[1].keys())
    assert keys_after != keys_before
    assert len(scores) == 2
7988
89+
def test_dend_transform(Xy, fit_dend, capsys):
    """The column count of ``transform`` output matches what it prints.

    ``transform`` is called twice (with 99 and with 50) and the captured
    stdout of each call is sliced at fixed positions to recover a number.

    NOTE(review): the slices ``[10:12]`` and ``[-3:-1]`` assume a specific
    message layout and a two-character count -- confirm against the text
    printed by ``Dendrogram.transform``.
    """
    features, _ = Xy

    # Drain the capture buffer after each call so the outputs stay separate.
    first = fit_dend.transform(features, 99)
    first_stdout = capsys.readouterr()[0]
    second = fit_dend.transform(features, 50)
    second_stdout = capsys.readouterr()[0]

    first_cols = first.shape[1]
    assert first_cols == int(first_stdout[10:12])
    second_cols = second.shape[1]
    assert second_cols == int(second_stdout[-3:-1])
8999
100+
def test_dend_plot(fit_dend):
    """Smoke test: build and show the dendrogram with and without ``static``.

    There are no assertions; the test passes if neither call raises.
    """
    static_figure = dendrogram(fit_dend)
    show(static_figure, static=True)

    default_figure = dendrogram(fit_dend)
    show(default_figure)
93104
105+
def test_build_edges(capsys):
    """``_build_edges`` prints a two-line notice for a 501-entry ``adj``.

    An unfitted ``Dendrogram`` is given a 501-element ``adj`` array and
    ``_build_edges(None)`` is invoked; the first two lines written to stdout
    must match the expected notice exactly.

    NOTE(review): 501 entries presumably exceeds a 500-graph limit inside
    ``_build_edges`` -- confirm against the implementation.
    """
    fake_sel = selection.Dendrogram()
    fake_sel.adj = np.asarray(range(501))
    fake_sel._build_edges(None)

    output, _ = capsys.readouterr()
    # splitlines() separates the printed lines on the newline itself; a
    # separator carrying a trailing space would leave both lines fused.
    printed_lines = output.splitlines()

    real_line_1 = 'Calculating more than 500 graphs'
    real_line_2 = 'You can pass max_threshes as a kwarg to Dendrogram'
    assert printed_lines[0] == real_line_1
    assert printed_lines[1] == real_line_2
106118
119+
def test_build_graphs_exit():
    """Smoke test: ``_build_graphs`` runs on a hand-populated ``Dendrogram``.

    ``threshlist``, ``edges`` and ``graphs`` are assigned directly on an
    unfitted instance before the call. There are no assertions; the test
    passes if ``_build_graphs`` returns without raising.

    NOTE(review): judging by the name, this presumably targets an early-exit
    path taken when ``graphs`` is already populated -- confirm.
    """
    stub = selection.Dendrogram()

    two_thresholds = [1, 2]
    edges_per_threshold = [[(0, 1)], [(1, 2), (0, 1)]]
    graphs_per_threshold = [{0: {0, 1}, 2: {2}}, {0: {0, 1, 2}}]

    stub.threshlist = two_thresholds
    stub.edges = edges_per_threshold
    stub.graphs = graphs_per_threshold

    stub._build_graphs()
113126
0 commit comments