Skip to content

Commit 2230b7c

Browse files
authored
Merge pull request #1420 from astaric/tree-interface
Tree: Common mixin for regression and classification tree
2 parents e0145bd + 4bc6545 commit 2230b7c

File tree

7 files changed

+122
-21
lines changed

7 files changed

+122
-21
lines changed

Orange/base.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,40 @@ def supports_weights(self):
266266
"""Indicates whether this learner supports weighted instances.
267267
"""
268268
return 'sample_weight' in self.__wraps__.fit.__code__.co_varnames
269+
270+
271+
class Tree:
    """Interface for tree-based models.

    Defines the members needed for drawing the tree.
    """

    #: Domain of the data the tree was built from
    domain = None

    #: Data the tree was built from (optional)
    instances = None

    @property
    def tree(self):
        """Return the underlying tree representation.

        This base implementation is a placeholder; concrete models are
        expected to override it.

        Returns
        -------
        sklearn.tree._tree.Tree

        Raises
        ------
        NotImplementedError
            Always, when the property has not been overridden.
        """
        raise NotImplementedError()
292+
293+
294+
class RandomForest:
    """Interface for random forest models."""

    @property
    def trees(self):
        """Return a list of Trees in the forest.

        This base implementation is a placeholder; concrete models are
        expected to override it. It raises instead of silently returning
        ``None`` so that a missing override is detected at the point of use.

        Returns
        -------
        List[Tree]

        Raises
        ------
        NotImplementedError
            Always, when the property has not been overridden.
        """
        raise NotImplementedError()

Orange/classification/random_forest.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import sklearn.ensemble as skl_ensemble
22

3+
from Orange.base import RandomForest
34
from Orange.classification import SklLearner, SklModel
5+
from Orange.classification.tree import TreeClassifier
46
from Orange.data import Variable, DiscreteVariable
57
from Orange.preprocess.score import LearnerScorer
68

@@ -16,8 +18,19 @@ def score(self, data):
1618
return model.skl_model.feature_importances_
1719

1820

19-
class RandomForestClassifier(SklModel):
20-
pass
21+
class RandomForestClassifier(SklModel, RandomForest):
    """Random forest classification model that exposes its member trees."""

    @property
    def trees(self):
        """Return the forest's estimators wrapped as ``TreeClassifier`` models.

        Every wrapper receives this model's ``domain``,
        ``supports_multiclass`` and ``original_domain``, and a name of the
        form ``"<forest name> - tree <index>"``. A fresh list of wrappers is
        built on each access.

        Returns
        -------
        List[TreeClassifier]
        """
        wrapped = []
        for index, estimator in enumerate(self.skl_model.estimators_):
            member = TreeClassifier(estimator)
            member.domain = self.domain
            member.supports_multiclass = self.supports_multiclass
            member.name = "{} - tree {}".format(self.name, index)
            member.original_domain = self.original_domain
            wrapped.append(member)
        return wrapped
2134

2235

2336
class RandomForestLearner(SklLearner, _FeatureScorerMixin):

Orange/classification/tree.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
import sklearn.tree as skl_tree
2+
3+
from Orange.base import Tree
24
from Orange.classification import SklLearner, SklModel
35
from Orange.preprocess import (RemoveNaNClasses, Continuize,
46
RemoveNaNColumns, SklImpute)
57

68
__all__ = ["TreeLearner"]
79

810

9-
class TreeClassifier(SklModel):
10-
pass
11+
class TreeClassifier(SklModel, Tree):
    """Classification tree model implementing the ``Tree`` interface."""

    @property
    def tree(self):
        """Return the ``tree_`` attribute of the wrapped ``skl_model``."""
        fitted = self.skl_model
        return fitted.tree_
1115

1216

1317
class TreeLearner(SklLearner):

Orange/regression/random_forest.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import sklearn.ensemble as skl_ensemble
2+
3+
from Orange.base import RandomForest
24
from Orange.regression import SklLearner, SklModel
35
from Orange.data import Variable, ContinuousVariable
46
from Orange.preprocess.score import LearnerScorer
7+
from Orange.regression.tree import TreeRegressor
58

69
__all__ = ["RandomForestRegressionLearner"]
710

@@ -15,8 +18,19 @@ def score(self, data):
1518
return model.skl_model.feature_importances_
1619

1720

18-
class RandomForestRegressor(SklModel):
19-
pass
21+
class RandomForestRegressor(SklModel, RandomForest):
    """Random forest regression model that exposes its member trees."""

    @property
    def trees(self):
        """Return the forest's estimators wrapped as ``TreeRegressor`` models.

        Every wrapper receives this model's ``domain``,
        ``supports_multiclass`` and ``original_domain``, and a name of the
        form ``"<forest name> - tree <index>"``. A fresh list of wrappers is
        built on each access.

        Returns
        -------
        List[TreeRegressor]
        """
        wrapped = []
        for index, estimator in enumerate(self.skl_model.estimators_):
            member = TreeRegressor(estimator)
            member.domain = self.domain
            member.supports_multiclass = self.supports_multiclass
            member.name = "{} - tree {}".format(self.name, index)
            member.original_domain = self.original_domain
            wrapped.append(member)
        return wrapped
2034

2135

2236
class RandomForestRegressionLearner(SklLearner, _FeatureScorerMixin):

Orange/regression/tree.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import sklearn.tree as skl_tree
2+
3+
from Orange.base import Tree
24
from Orange.regression import SklLearner, SklModel
35
from Orange.preprocess import Continuize, RemoveNaNColumns, SklImpute
46

57
__all__ = ["TreeRegressionLearner"]
68

79

8-
class TreeRegressor(SklModel):
9-
pass
10+
class TreeRegressor(SklModel, Tree):
    """Regression tree model implementing the ``Tree`` interface."""

    @property
    def tree(self):
        """Return the ``tree_`` attribute of the wrapped ``skl_model``."""
        fitted = self.skl_model
        return fitted.tree_
1014

1115

1216
class TreeRegressionLearner(SklLearner):

Orange/tests/test_random_forest.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
from Orange.regression import RandomForestRegressionLearner
1010
from Orange.tests import test_filename
1111

12+
1213
class RandomForestTest(unittest.TestCase):
1314
@classmethod
def setUpClass(cls):
    """Load the iris and housing tables once and share them across tests."""
    # Each table is stored under the attribute named after its dataset.
    for dataset in ('iris', 'housing'):
        setattr(cls, dataset, Table(dataset))
1718

1819
def test_RandomForest(self):
1920
forest = RandomForestLearner()
@@ -43,28 +44,28 @@ def test_predict_numpy(self):
4344

4445
def test_RandomForestRegression(self):
    """Cross-validate a regression forest on housing and compute its RMSE."""
    learner = RandomForestRegressionLearner()
    cv_results = CrossValidation(self.housing, [learner], k=10)
    # Smoke test: only checks that scoring runs without raising.
    RMSE(cv_results)
4849

4950
def test_predict_single_instance_reg(self):
    """Predict each housing instance individually; each must be positive."""
    learner = RandomForestRegressionLearner()
    fitted = learner(self.housing)
    for instance in self.housing:
        prediction = fitted(instance)
        self.assertGreater(prediction, 0)
5556

5657
def test_predict_table_reg(self):
    """Predict the whole housing table at once and validate the output.

    Checks that one prediction is produced per instance and that every
    prediction is strictly positive.
    """
    forest = RandomForestRegressionLearner()
    model = forest(self.housing)
    pred = model(self.housing)
    self.assertEqual(len(self.housing), len(pred))
    # `all(pred)` collapses to a bool, so comparing it with 0 only verified
    # that no prediction was exactly zero; compare each value instead.
    self.assertTrue(all(p > 0 for p in pred))
6263

6364
def test_predict_numpy_reg(self):
    """Predict from the housing table's ``X`` array and validate the output.

    Checks that one prediction is produced per instance and that every
    prediction is strictly positive.
    """
    forest = RandomForestRegressionLearner()
    model = forest(self.housing)
    pred = model(self.housing.X)
    self.assertEqual(len(self.housing), len(pred))
    # `all(pred)` collapses to a bool, so comparing it with 0 only verified
    # that no prediction was exactly zero; compare each value instead.
    self.assertTrue(all(p > 0 for p in pred))
6970

7071
def test_classification_scorer(self):
@@ -78,9 +79,9 @@ def test_classification_scorer(self):
7879

7980
def test_regression_scorer(self):
    """The two highest-scored housing attributes must be LSTAT and RM."""
    learner = RandomForestRegressionLearner()
    scores = learner.score_data(self.housing)
    attributes = self.housing.domain.attributes
    # argsort is ascending, so the last two indices are the top scorers.
    top_two = np.argsort(scores[0])[-2:]
    best_names = sorted([attributes[i].name for i in top_two])
    self.assertEqual(['LSTAT', 'RM'], best_names)
8586

8687
def test_scorer_feature(self):
@@ -92,3 +93,19 @@ def test_scorer_feature(self):
9293
np.random.seed(42)
9394
score = learner.score_data(data, attr)
9495
np.testing.assert_array_almost_equal(score, scores[:, i])
96+
97+
def test_get_classification_trees(self):
    """A classification forest exposes one usable tree per estimator."""
    tree_count = 5
    learner = RandomForestLearner(n_estimators=tree_count)
    fitted = learner(self.iris)
    self.assertEqual(len(fitted.trees), tree_count)
    first_tree = fitted.trees[0]
    # The wrapped tree must itself be callable as a model.
    self.assertEqual(first_tree(self.iris[0]), 0)
104+
105+
def test_get_regression_trees(self):
    """A regression forest exposes one usable tree per estimator."""
    tree_count = 5
    learner = RandomForestRegressionLearner(n_estimators=tree_count)
    fitted = learner(self.housing)
    self.assertEqual(len(fitted.trees), tree_count)
    first_tree = fitted.trees[0]
    # Smoke test: the wrapped tree must be callable on a single instance.
    first_tree(self.housing[0])

Orange/tests/test_tree.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import numpy as np
77
import sklearn.tree as skl_tree
8-
from sklearn.tree._tree import TREE_LEAF
8+
from sklearn.tree._tree import TREE_LEAF, Tree
99

1010
from Orange.data import Table
1111
from Orange.classification import TreeLearner
@@ -27,6 +27,18 @@ def test_regression(self):
2727
pred = model(table)
2828
self.assertTrue(np.all(table.Y.flatten() == pred))
2929

30+
def test_get_tree_classification(self):
    """A fitted classification tree exposes a scikit-learn ``Tree``."""
    data = Table('iris')
    learner = TreeLearner()
    fitted = learner(data)
    self.assertIsInstance(fitted.tree, Tree)
35+
36+
def test_get_tree_regression(self):
    """A fitted regression tree exposes a scikit-learn ``Tree``."""
    data = Table('housing')
    learner = TreeRegressionLearner()
    fitted = learner(data)
    self.assertIsInstance(fitted.tree, Tree)
41+
3042

3143
class TestDecisionTreeClassifier(unittest.TestCase):
3244
@classmethod

0 commit comments

Comments
 (0)