Skip to content

Commit 0793ff0

Browse files
committed
added unit test with relevant/irrelevant features
1 parent a246d67 commit 0793ff0

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

boruta/test/unit_tests.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,46 @@
11
import unittest
22
from boruta import BorutaPy
33
from sklearn.ensemble import RandomForestClassifier
4+
import numpy as np
45

56

67
class BorutaTestCases(unittest.TestCase):
78

89
def test_get_tree_num(self):
910
rfc = RandomForestClassifier(max_depth=10)
1011
bt = BorutaPy(rfc)
11-
self.assertEqual(bt._get_tree_num(10),44,"Tree Est. Math Fail")
12-
self.assertEqual(bt._get_tree_num(100),141,"Tree Est. Math Fail")
12+
self.assertEqual(bt._get_tree_num(10), 44, "Tree Est. Math Fail")
13+
self.assertEqual(bt._get_tree_num(100), 141, "Tree Est. Math Fail")
14+
15+
def test_if_boruta_extracts_relevant_features(self):
16+
np.random.seed(42)
17+
y = np.random.binomial(1, 0.5, 1000)
18+
X = np.zeros((1000, 10))
19+
20+
z = y - np.random.binomial(1, 0.1, 1000) + np.random.binomial(1, 0.1, 1000)
21+
z[z == -1] = 0
22+
z[z == 2] = 1
23+
24+
# 5 relevant features
25+
X[:, 0] = z
26+
X[:, 1] = y * np.abs(np.random.normal(0, 1, 1000)) + np.random.normal(0, 0.1, 1000)
27+
X[:, 2] = y + np.random.normal(0, 1, 1000)
28+
X[:, 3] = y ** 2 + np.random.normal(0, 1, 1000)
29+
X[:, 4] = np.sqrt(y) + np.random.binomial(2, 0.1, 1000)
30+
31+
# 5 irrelevant features
32+
X[:, 5] = np.random.normal(0, 1, 1000)
33+
X[:, 6] = np.random.poisson(1, 1000)
34+
X[:, 7] = np.random.binomial(1, 0.3, 1000)
35+
X[:, 8] = np.random.normal(0, 1, 1000)
36+
X[:, 9] = np.random.poisson(1, 1000)
37+
38+
rfc = RandomForestClassifier()
39+
bt = BorutaPy(rfc)
40+
bt.fit(X, y)
1341

42+
# make sure that only all the relevant features are returned
43+
self.assertItemsEqual(range(5), list(np.where(bt.support_)[0]))
1444

1545

1646
if __name__ == '__main__':

0 commit comments

Comments
 (0)