Skip to content

Commit be837e0

Browse files
committed
fix selection
1 parent 5314c37 commit be837e0

File tree

2 files changed

+239
-7
lines changed

2 files changed

+239
-7
lines changed

src/hidimstat/base_variable_importance.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,23 @@ def selection(
7070
"""
7171
self._check_importance()
7272
if k_best is not None:
73-
if not isinstance(k_best, str) and k_best > self.importances_.shape[1]:
73+
if not isinstance(k_best, str) and k_best > self.importances_.shape[0]:
7474
warnings.warn(
75-
f"k={k_best} is greater than n_features={self.importances_.shape[1]}. "
75+
f"k={k_best} is greater than n_features={self.importances_.shape[0]}. "
7676
"All the features will be returned."
7777
)
78-
assert k_best > 0, "k_best needs to be positive and not null"
78+
if isinstance(k_best, str):
79+
assert k_best == "all"
80+
else:
81+
assert k_best >= 0, "k_best needs to be positive or null"
7982
if percentile is not None:
8083
assert (
81-
0 < percentile and percentile < 100
84+
0 <= percentile and percentile <= 100
8285
), "percentile needs to be between 0 and 100"
8386
if threshold_pvalue is not None:
87+
assert (
88+
self.pvalues_ is not None
89+
), "This method doesn't support a threshold on p-values"
8490
assert (
8591
0 < threshold_pvalue and threshold_pvalue < 1
8692
), "threshold_pvalue needs to be between 0 and 1"
@@ -105,9 +111,9 @@ def selection(
105111
elif percentile == 0:
106112
mask_percentile = np.zeros(len(self.importances_), dtype=bool)
107113
elif percentile is not None:
108-
threshold = np.percentile(self.importances_, 100 - percentile)
109-
mask_percentile = self.importances_ > threshold
110-
ties = np.where(self.importances_ == threshold)[0]
114+
threshold_percentile = np.percentile(self.importances_, 100 - percentile)
115+
mask_percentile = self.importances_ > threshold_percentile
116+
ties = np.where(self.importances_ == threshold_percentile)[0]
111117
if len(ties):
112118
max_feats = int(len(self.importances_) * percentile / 100)
113119
kept_ties = ties[: max_feats - mask_percentile.sum()]
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
import pytest
2+
import numpy as np
3+
4+
from hidimstat import BaseVariableImportance
5+
6+
7+
@pytest.fixture
def set_BaseVariableImportance(pvalues, test_score, seed):
    """Build a ``BaseVariableImportance`` filled with synthetic results.

    Parameters
    ----------
    pvalues : bool
        When true, ``pvalues_`` is populated.
    test_score : bool
        When true, both ``pvalues_`` and ``test_scores_`` are populated.
    seed : int
        Seed of the random generator used for *every* random draw, so the
        fixture is fully reproducible for a given parametrization.

    Returns
    -------
    BaseVariableImportance
        Instance whose ``importances_`` is a shuffled permutation of
        ``0 .. 99``.
    """
    nb_features = 100
    rng = np.random.RandomState(seed)
    vi = BaseVariableImportance()
    # Importances are the integers 0..nb_features-1 in random order, so the
    # tests can express the expected selection as a comparison on the value.
    vi.importances_ = np.arange(nb_features)
    rng.shuffle(vi.importances_)
    if pvalues or test_score:
        # Indexing the sorted p-values by importance gives the feature with
        # importance k the k-th smallest p-value.
        vi.pvalues_ = np.sort(rng.rand(nb_features))[vi.importances_]
    if test_score:
        vi.test_scores_ = []
        # Ten unstructured score vectors. Drawn from the seeded generator
        # rather than the global ``np.random`` state so runs are repeatable.
        for _ in range(10):
            vi.test_scores_.append(rng.rand(nb_features) * 30)
        # Structured score vectors: the last ``i`` positions get large
        # positive values and the first ``i`` positions negative ones, then
        # the vector is reordered to follow the importances.
        for i in range(1, 30):
            score = rng.rand(nb_features) + 1
            score[-i:] = np.arange(30 - i, 30) * 2
            score[:i] = -np.arange(30 - i, 30)
            vi.test_scores_.append(score[vi.importances_])
    return vi
27+
28+
29+
@pytest.mark.parametrize(
    "pvalues, test_score, seed",
    [(False, False, 0), (True, False, 1), (True, True, 2)],
    ids=["only importance", "p-value", "test-score"],
)
class TestSelection:
    """Tests of ``selection`` based on importance.

    Every test receives the ``set_BaseVariableImportance`` fixture, whose
    importances are a permutation of ``0 .. 99``; the expected mask is
    therefore written as a comparison on the importance values.
    """

    def test_selection_k_best(self, set_BaseVariableImportance):
        "test selection of the k best features"
        vi = set_BaseVariableImportance
        true_value = vi.importances_ >= 95
        selection = vi.selection(k_best=5)
        np.testing.assert_array_equal(true_value, selection)

    def test_selection_k_best_all(self, set_BaseVariableImportance):
        "test selection of all the features with k_best='all'"
        vi = set_BaseVariableImportance
        true_value = np.ones_like(vi.importances_, dtype=bool)
        selection = vi.selection(k_best="all")
        np.testing.assert_array_equal(true_value, selection)

    def test_selection_k_best_none(self, set_BaseVariableImportance):
        "test selection of no feature with k_best=0"
        vi = set_BaseVariableImportance
        true_value = np.zeros_like(vi.importances_, dtype=bool)
        selection = vi.selection(k_best=0)
        np.testing.assert_array_equal(true_value, selection)

    def test_selection_percentile(self, set_BaseVariableImportance):
        "test selection based on percentile"
        vi = set_BaseVariableImportance
        true_value = vi.importances_ >= 50
        selection = vi.selection(percentile=50)
        np.testing.assert_array_equal(true_value, selection)

    def test_selection_percentile_all(self, set_BaseVariableImportance):
        "test selection when percentile is 100"
        vi = set_BaseVariableImportance
        true_value = np.ones_like(vi.importances_, dtype=bool)
        selection = vi.selection(percentile=100)
        np.testing.assert_array_equal(true_value, selection)

    def test_selection_percentile_none(self, set_BaseVariableImportance):
        "test selection when percentile is 0"
        vi = set_BaseVariableImportance
        true_value = np.zeros_like(vi.importances_, dtype=bool)
        selection = vi.selection(percentile=0)
        np.testing.assert_array_equal(true_value, selection)

    def test_selection_percentile_threshols_value(self, set_BaseVariableImportance):
        "test selection when the percentile threshold equals an importance value"
        vi = set_BaseVariableImportance
        # Dropping the feature with importance 99 leaves an odd number of
        # features, so the 50th percentile falls exactly on one importance.
        mask = np.ones_like(vi.importances_, dtype=bool)
        mask[np.where(vi.importances_ == 99)] = False
        vi.importances_ = vi.importances_[mask]
        true_value = vi.importances_ >= 50
        selection = vi.selection(percentile=50)
        np.testing.assert_array_equal(true_value, selection)

    def test_selection_threshold(self, set_BaseVariableImportance):
        "test threshold on importance"
        vi = set_BaseVariableImportance
        true_value = vi.importances_ < 5
        selection = vi.selection(threshold=5)
        np.testing.assert_array_equal(true_value, selection)

    def test_selection_threshold_pvalue(self, set_BaseVariableImportance):
        "test threshold based on p-values"
        vi = set_BaseVariableImportance
        if vi.pvalues_ is None:
            # Report the case as skipped instead of passing without
            # asserting anything.
            pytest.skip("no p-values available for this parametrization")
        true_value = vi.importances_ < 5
        # The threshold is the p-value of the feature whose importance is 5.
        selection = vi.selection(
            threshold_pvalue=vi.pvalues_[np.argsort(vi.importances_)[5]]
        )
        np.testing.assert_array_equal(true_value, selection)
106+
107+
108+
@pytest.mark.parametrize(
    "pvalues, test_score, seed",
    [(True, True, 10)],
    ids=["default"],
)
class TestSelectionFDR:
    """Tests of ``selection_fdr`` on a fixture carrying test scores."""

    def test_selection_fdr_default(self, set_BaseVariableImportance):
        "selection_fdr with the default options"
        vi = set_BaseVariableImportance
        expected = vi.importances_ >= 85
        np.testing.assert_array_equal(expected, vi.selection_fdr(0.2))

    def test_selection_fdr_adaptation(self, set_BaseVariableImportance):
        "selection_fdr with adaptive aggregation enabled"
        vi = set_BaseVariableImportance
        expected = vi.importances_ >= 85
        selected = vi.selection_fdr(0.2, adaptive_aggregation=True)
        np.testing.assert_array_equal(expected, selected)

    def test_selection_fdr_bhy(self, set_BaseVariableImportance):
        "selection_fdr with fdr_control='bhy'"
        vi = set_BaseVariableImportance
        expected = vi.importances_ >= 85
        selected = vi.selection_fdr(0.8, fdr_control="bhy")
        np.testing.assert_array_equal(expected, selected)

    def test_selection_fdr_ebh(self, set_BaseVariableImportance):
        "selection_fdr with fdr_control='ebh' and evalues=True"
        vi = set_BaseVariableImportance
        expected = vi.importances_ >= 2
        selected = vi.selection_fdr(0.037, fdr_control="ebh", evalues=True)
        np.testing.assert_array_equal(expected, selected)
141+
142+
143+
@pytest.mark.parametrize(
    "pvalues, test_score, seed",
    [(False, False, 0), (True, False, 0), (True, True, 0)],
    ids=["only importance", "p-value", "test-score"],
)
class TestBVIExceptions:
    """Checks of the errors and warnings raised by BaseVariableImportance."""

    def test_not_fit(self, pvalues, test_score, seed):
        "an instance without importances is rejected by every entry point"
        vi = BaseVariableImportance()
        unfit_message = "The importances need to be called before calling this method"
        guarded_calls = (
            vi._check_importance,
            vi.selection,
            lambda: vi.selection_fdr(0.1),
        )
        for call in guarded_calls:
            with pytest.raises(ValueError, match=unfit_message):
                call()

    def test_selection_k_best(self, set_BaseVariableImportance):
        "invalid k_best: negative value asserts, too large value warns"
        vi = set_BaseVariableImportance
        negative_message = "k_best needs to be positive or null"
        with pytest.raises(AssertionError, match=negative_message):
            vi.selection(k_best=-10)
        too_large_message = "k=1000 is greater than n_features="
        with pytest.warns(Warning, match=too_large_message):
            vi.selection(k_best=1000)

    def test_selection_percentile(self, set_BaseVariableImportance):
        "percentile outside of [0, 100] is rejected"
        vi = set_BaseVariableImportance
        range_message = "percentile needs to be between 0 and 100"
        for invalid_percentile in (-1, 102):
            with pytest.raises(AssertionError, match=range_message):
                vi.selection(percentile=invalid_percentile)

    def test_selection_threshold(self, set_BaseVariableImportance):
        "invalid threshold_pvalue, or threshold_pvalue without p-values"
        vi = set_BaseVariableImportance
        if vi.pvalues_ is None:
            # Without p-values the request itself is rejected, whatever the
            # value of the threshold.
            with pytest.raises(
                AssertionError,
                match="This method doesn't support a threshold on p-values",
            ):
                vi.selection(threshold_pvalue=-1)
            return
        range_message = "threshold_pvalue needs to be between 0 and 1"
        for invalid_threshold in (-1, 1.1):
            with pytest.raises(AssertionError, match=range_message):
                vi.selection(threshold_pvalue=invalid_threshold)

    def test_selection_fdr_fdr_control(self, set_BaseVariableImportance):
        "inconsistent fdr_control / evalues, or FDR selection without scores"
        vi = set_BaseVariableImportance
        if vi.test_scores_ is None:
            with pytest.raises(
                AssertionError,
                match="this method doesn't support selection base on FDR",
            ):
                vi.selection_fdr(fdr=0.1)
            return
        with pytest.raises(
            AssertionError, match="for e-value, the fdr control need to be 'ebh'"
        ):
            vi.selection_fdr(fdr=0.1, evalues=True)
        with pytest.raises(
            AssertionError, match="for p-value, the fdr control can't be 'ebh'"
        ):
            vi.selection_fdr(fdr=0.1, fdr_control="ebh", evalues=False)

0 commit comments

Comments
 (0)