Skip to content

Commit 4bc6545

Browse files
committed
RandomForest: Common mixin for regression and classification
1 parent ec09e9a commit 4bc6545

File tree

4 files changed

+74
-16
lines changed

4 files changed

+74
-16
lines changed

Orange/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,3 +289,17 @@ def tree(self):
289289
sklearn.tree._tree.Tree
290290
"""
291291
raise NotImplementedError()
292+
293+
294+
class RandomForest:
295+
"""Interface for random forest models
296+
"""
297+
298+
@property
299+
def trees(self):
300+
"""Return a list of Trees in the forest
301+
302+
Returns
303+
-------
304+
List[Tree]
305+
"""

Orange/classification/random_forest.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import sklearn.ensemble as skl_ensemble
22

3+
from Orange.base import RandomForest
34
from Orange.classification import SklLearner, SklModel
5+
from Orange.classification.tree import TreeClassifier
46
from Orange.data import Variable, DiscreteVariable
57
from Orange.preprocess.score import LearnerScorer
68

@@ -16,8 +18,19 @@ def score(self, data):
1618
return model.skl_model.feature_importances_
1719

1820

19-
class RandomForestClassifier(SklModel):
20-
pass
21+
class RandomForestClassifier(SklModel, RandomForest):
22+
@property
23+
def trees(self):
24+
def wrap(tree, i):
25+
t = TreeClassifier(tree)
26+
t.domain = self.domain
27+
t.supports_multiclass = self.supports_multiclass
28+
t.name = "{} - tree {}".format(self.name, i)
29+
t.original_domain = self.original_domain
30+
return t
31+
32+
return [wrap(tree, i)
33+
for i, tree in enumerate(self.skl_model.estimators_)]
2134

2235

2336
class RandomForestLearner(SklLearner, _FeatureScorerMixin):

Orange/regression/random_forest.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import sklearn.ensemble as skl_ensemble
2+
3+
from Orange.base import RandomForest
24
from Orange.regression import SklLearner, SklModel
35
from Orange.data import Variable, ContinuousVariable
46
from Orange.preprocess.score import LearnerScorer
7+
from Orange.regression.tree import TreeRegressor
58

69
__all__ = ["RandomForestRegressionLearner"]
710

@@ -15,8 +18,19 @@ def score(self, data):
1518
return model.skl_model.feature_importances_
1619

1720

18-
class RandomForestRegressor(SklModel):
19-
pass
21+
class RandomForestRegressor(SklModel, RandomForest):
22+
@property
23+
def trees(self):
24+
def wrap(tree, i):
25+
t = TreeRegressor(tree)
26+
t.domain = self.domain
27+
t.supports_multiclass = self.supports_multiclass
28+
t.name = "{} - tree {}".format(self.name, i)
29+
t.original_domain = self.original_domain
30+
return t
31+
32+
return [wrap(tree, i)
33+
for i, tree in enumerate(self.skl_model.estimators_)]
2034

2135

2236
class RandomForestRegressionLearner(SklLearner, _FeatureScorerMixin):

Orange/tests/test_random_forest.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
from Orange.regression import RandomForestRegressionLearner
1010
from Orange.tests import test_filename
1111

12+
1213
class RandomForestTest(unittest.TestCase):
1314
@classmethod
1415
def setUpClass(cls):
1516
cls.iris = Table('iris')
16-
cls.house = Table('housing')
17+
cls.housing = Table('housing')
1718

1819
def test_RandomForest(self):
1920
forest = RandomForestLearner()
@@ -43,28 +44,28 @@ def test_predict_numpy(self):
4344

4445
def test_RandomForestRegression(self):
4546
forest = RandomForestRegressionLearner()
46-
results = CrossValidation(self.house, [forest], k=10)
47+
results = CrossValidation(self.housing, [forest], k=10)
4748
_ = RMSE(results)
4849

4950
def test_predict_single_instance_reg(self):
5051
forest = RandomForestRegressionLearner()
51-
model = forest(self.house)
52-
for ins in self.house:
52+
model = forest(self.housing)
53+
for ins in self.housing:
5354
pred = model(ins)
5455
self.assertGreater(pred, 0)
5556

5657
def test_predict_table_reg(self):
5758
forest = RandomForestRegressionLearner()
58-
model = forest(self.house)
59-
pred = model(self.house)
60-
self.assertEqual(len(self.house), len(pred))
59+
model = forest(self.housing)
60+
pred = model(self.housing)
61+
self.assertEqual(len(self.housing), len(pred))
6162
self.assertGreater(all(pred), 0)
6263

6364
def test_predict_numpy_reg(self):
6465
forest = RandomForestRegressionLearner()
65-
model = forest(self.house)
66-
pred = model(self.house.X)
67-
self.assertEqual(len(self.house), len(pred))
66+
model = forest(self.housing)
67+
pred = model(self.housing.X)
68+
self.assertEqual(len(self.housing), len(pred))
6869
self.assertGreater(all(pred), 0)
6970

7071
def test_classification_scorer(self):
@@ -78,9 +79,9 @@ def test_classification_scorer(self):
7879

7980
def test_regression_scorer(self):
8081
learner = RandomForestRegressionLearner()
81-
scores = learner.score_data(self.house)
82+
scores = learner.score_data(self.housing)
8283
self.assertEqual(['LSTAT', 'RM'],
83-
sorted([self.house.domain.attributes[i].name
84+
sorted([self.housing.domain.attributes[i].name
8485
for i in np.argsort(scores[0])[-2:]]))
8586

8687
def test_scorer_feature(self):
@@ -92,3 +93,19 @@ def test_scorer_feature(self):
9293
np.random.seed(42)
9394
score = learner.score_data(data, attr)
9495
np.testing.assert_array_almost_equal(score, scores[:, i])
96+
97+
def test_get_classification_trees(self):
98+
n = 5
99+
forest = RandomForestLearner(n_estimators=n)
100+
model = forest(self.iris)
101+
self.assertEqual(len(model.trees), n)
102+
tree = model.trees[0]
103+
self.assertEqual(tree(self.iris[0]), 0)
104+
105+
def test_get_regression_trees(self):
106+
n = 5
107+
forest = RandomForestRegressionLearner(n_estimators=n)
108+
model = forest(self.housing)
109+
self.assertEqual(len(model.trees), n)
110+
tree = model.trees[0]
111+
tree(self.housing[0])

0 commit comments

Comments
 (0)