Skip to content

Commit c5cf826

Browse files
committed
Move Stacking to core Orange
1 parent 1e8a286 commit c5cf826

File tree

5 files changed

+250
-0
lines changed

5 files changed

+250
-0
lines changed

Orange/ensembles/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1+
# pylint: disable=wildcard-import
2+
13
from .ada_boost import *
4+
from .stack import *

Orange/ensembles/stack.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import numpy as np
2+
3+
from Orange.base import Learner, Model
4+
from Orange.modelling import Fitter
5+
from Orange.classification import LogisticRegressionLearner
6+
from Orange.classification.base_classification import LearnerClassification
7+
from Orange.data import Domain, ContinuousVariable, Table
8+
from Orange.evaluation import CrossValidation
9+
from Orange.regression import RidgeRegressionLearner
10+
from Orange.regression.base_regression import LearnerRegression
11+
12+
13+
# Names exported by ``from Orange.ensembles.stack import *``.
__all__ = [
    'StackedLearner',
    'StackedClassificationLearner',
    'StackedRegressionLearner',
    'StackedFitter',
]
15+
16+
17+
class StackedModel(Model):
    """
    Model that passes the outputs of base models through an aggregate model.

    Args:
        models (list):
            fitted base models; each one is called on the data to predict

        aggregate (Model):
            fitted model called on the stacked base-model outputs; its
            `domain` is the domain the input data is transformed to

        use_prob (bool):
            if true, base models are asked for `Model.Probs` and the
            aggregate for `Model.ValueProbs`; otherwise base models are
            called without a return type and the aggregate is asked for
            `Model.Value`

        domain:
            passed on to `Model.__init__`
    """

    def __init__(self, models, aggregate, use_prob=True, domain=None):
        super().__init__(domain=domain)
        self.models = models
        self.aggregate = aggregate
        self.use_prob = use_prob

    def predict_storage(self, data):
        """Return the aggregate model's output for `data`."""
        if self.use_prob:
            # One block of probability columns per base model, side by side.
            features = np.hstack(
                [model(data, Model.Probs) for model in self.models])
            ret = Model.ValueProbs
        else:
            # One column of predicted values per base model.
            features = np.column_stack(
                [model(data) for model in self.models])
            ret = Model.Value
        meta_data = data.transform(self.aggregate.domain)
        meta_data.X = features
        # The target is what is being predicted, so it is filled with NaN.
        meta_data.Y = np.full(features.shape[0], np.nan)
        return self.aggregate(meta_data, ret)
37+
38+
39+
class StackedLearner(Learner):
    """
    Builds a stacked model: an aggregator fitted on the outputs
    of a set of base models.

    The predictions of the base learners are obtained with k-fold
    cross-validation; the aggregator is then fitted on those predictions
    to produce the stacked model.

    Args:
        learners (list):
            list of `Learner`s used for base models

        aggregate (Learner):
            Learner used to fit the meta model, aggregating predictions
            of base models

        k (int):
            number of folds for cross-validation

        preprocessors:
            passed on to `Learner.__init__`

    Returns:
        instance of StackedModel
    """

    __returns__ = StackedModel

    def __init__(self, learners, aggregate, k=5, preprocessors=None):
        super().__init__(preprocessors=preprocessors)
        self.learners = learners
        self.aggregate = aggregate
        self.k = k
        # Snapshot of this constructor's local names; no extra locals are
        # introduced here so the captured mapping stays the same.
        self.params = vars()

    def fit_storage(self, data):
        """Fit the base models and the aggregator on `data`."""
        cv_results = CrossValidation(data, self.learners, k=self.k)
        use_prob = bool(data.domain.class_var.is_discrete)
        if use_prob:
            # Discrete target: stack the cross-validated probabilities.
            meta_x = np.hstack(cv_results.probabilities)
        else:
            # Otherwise one column of cross-validated predictions per learner.
            meta_x = cv_results.predicted.T

        n_features = meta_x.shape[1]
        attributes = [ContinuousVariable('f{}'.format(idx))
                      for idx in range(1, n_features + 1)]
        meta_domain = Domain(attributes, data.domain.class_var)

        meta_data = data.transform(meta_domain)
        meta_data.X = meta_x
        meta_data.Y = cv_results.actual

        # Base models used at prediction time are fitted on all of `data`.
        base_models = [learner(data) for learner in self.learners]
        meta_model = self.aggregate(meta_data)
        return StackedModel(base_models, meta_model, use_prob=use_prob,
                            domain=data.domain)
89+
90+
91+
class StackedClassificationLearner(StackedLearner, LearnerClassification):
    """
    Subclass of StackedLearner intended for classification tasks.

    Same as the super class, but has a default
    classification-specific aggregator (`LogisticRegressionLearner`).

    Args:
        learners (list):
            list of `Learner`s used for base models

        aggregate (Learner):
            Learner used to fit the meta model; when omitted or `None`,
            a new `LogisticRegressionLearner` is created for this instance

        k (int):
            number of folds for cross-validation

        preprocessors:
            passed on to `StackedLearner.__init__`
    """

    def __init__(self, learners, aggregate=None, k=5, preprocessors=None):
        # The default aggregator is created per instance. A learner object
        # in the signature would be built once at import time and shared
        # by every stacked learner. Accepting None also lets callers that
        # forward an unset aggregate get the default.
        if aggregate is None:
            aggregate = LogisticRegressionLearner()
        super().__init__(learners, aggregate, k=k, preprocessors=preprocessors)
102+
103+
104+
class StackedRegressionLearner(StackedLearner, LearnerRegression):
    """
    Subclass of StackedLearner intended for regression tasks.

    Same as the super class, but has a default
    regression-specific aggregator (`RidgeRegressionLearner`).

    Args:
        learners (list):
            list of `Learner`s used for base models

        aggregate (Learner):
            Learner used to fit the meta model; when omitted or `None`,
            a new `RidgeRegressionLearner` is created for this instance

        k (int):
            number of folds for cross-validation

        preprocessors:
            passed on to `StackedLearner.__init__`
    """

    def __init__(self, learners, aggregate=None, k=5, preprocessors=None):
        # The default aggregator is created per instance. A learner object
        # in the signature would be built once at import time and shared
        # by every stacked learner. Accepting None also lets callers that
        # forward an unset aggregate get the default.
        if aggregate is None:
            aggregate = RidgeRegressionLearner()
        super().__init__(learners, aggregate, k=k, preprocessors=preprocessors)
114+
115+
116+
class StackedFitter(Fitter):
    """
    Fitter that maps the problem type to the matching stacked learner.

    Args:
        learners (list):
            base learners, forwarded to the parent constructor under the
            `learners` keyword

        **kwargs:
            any further keyword arguments, forwarded unchanged
    """

    __fits__ = {
        'classification': StackedClassificationLearner,
        'regression': StackedRegressionLearner,
    }

    def __init__(self, learners, **kwargs):
        # `learners` goes last so it follows the caller's keywords.
        super().__init__(**kwargs, learners=learners)
123+
124+
125+
if __name__ == '__main__':
    # Manual demo: fit a stack on every other row, predict on the rest.
    import Orange

    iris = Table('iris')
    knn = Orange.modelling.KNNLearner()
    tree = Orange.modelling.TreeLearner()

    # Classification example.
    stacker = StackedFitter([tree, knn])
    iris_model = stacker(iris[::2])
    print(iris_model(iris[1::2], Model.Value))

    # Regression example: pair actual and predicted values for a few rows.
    housing = Table('housing')
    stacker = StackedFitter([tree, knn])
    housing_model = stacker(housing[::2])
    held_out = housing[1:10:2]
    print(list(zip(held_out.Y, housing_model(held_out, Model.Value))))

Orange/tests/test_stack.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import unittest
2+
3+
from Orange.data import Table
4+
from Orange.ensembles.stack import StackedFitter
5+
from Orange.evaluation import CA, CrossValidation, MSE
6+
from Orange.modelling import KNNLearner, TreeLearner
7+
8+
9+
class TestStackedFitter(unittest.TestCase):
    """Tests for StackedFitter on the iris and housing data sets."""

    @classmethod
    def setUpClass(cls):
        # Load both data sets once for the whole test case.
        cls.iris = Table('iris')
        cls.housing = Table('housing')

    def test_classification(self):
        """Cross-validated accuracy of the stack on iris exceeds 0.9."""
        stack = StackedFitter([TreeLearner(), KNNLearner()])
        accuracy = CA(CrossValidation(self.iris, [stack], k=3))
        self.assertGreater(accuracy, 0.9)

    def test_regression(self):
        """The stack has lower MSE than each of its base learners."""
        stack = StackedFitter([TreeLearner(), KNNLearner()])
        results = CrossValidation(
            self.housing[:50],
            [stack, TreeLearner(), KNNLearner()],
            k=3,
            random_state=0,
        )
        scores = MSE()(results)
        stack_mse = scores[0]
        # Index 1 is the tree, index 2 is kNN.
        for baseline_mse in (scores[1], scores[2]):
            self.assertLess(stack_mse, baseline_mse)
Lines changed: 13 additions & 0 deletions
Loading

Orange/widgets/model/owstack.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from collections import OrderedDict
2+
3+
from Orange.base import Learner
4+
from Orange.data import Table
5+
from Orange.ensembles.stack import StackedFitter
6+
from Orange.widgets.settings import Setting
7+
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
8+
from Orange.widgets.widget import Input
9+
10+
11+
class OWStackedLearner(OWBaseLearner):
    """Widget that builds a StackedFitter from the connected learners."""

    name = "Stacking"
    description = "Stack multiple models."
    icon = "icons/Stacking.svg"
    priority = 100

    LEARNER = StackedFitter

    learner_name = Setting("Stack")

    class Inputs(OWBaseLearner.Inputs):
        learners = Input("Learners", Learner, multiple=True)
        aggregate = Input("Aggregate", Learner)

    def __init__(self):
        # Both attributes are set before the base constructor runs.
        self.learners = OrderedDict()  # input id -> base learner
        self.aggregate = None
        super().__init__()

    def add_main_layout(self):
        """Add nothing to the main layout."""

    @Inputs.learners
    def set_learners(self, learner, id):
        """Store, replace or remove the learner for input `id`, then apply."""
        if learner is None:
            # Removes the entry if present; an unknown id is ignored.
            self.learners.pop(id, None)
        else:
            self.learners[id] = learner
        self.apply()

    @Inputs.aggregate
    def set_aggregate(self, aggregate):
        """Store the aggregate learner (may be None), then apply."""
        self.aggregate = aggregate
        self.apply()

    def create_learner(self):
        """Return a new `LEARNER` instance, or None without base learners."""
        if self.learners:
            return self.LEARNER(
                tuple(self.learners.values()),
                aggregate=self.aggregate,
                preprocessors=self.preprocessors,
            )
        return None

    def get_learner_parameters(self):
        """Return (label, value) pairs describing the current inputs."""
        base_names = [learner.name for learner in self.learners.values()]
        if self.aggregate:
            aggregator_name = self.aggregate.name
        else:
            aggregator_name = 'default'
        return (("Base learners", base_names),
                ("Aggregator", aggregator_name))
57+
58+
59+
if __name__ == "__main__":
    # Manual preview: show the widget with a data set named on the
    # command line, or with iris when no argument is given.
    import sys
    from AnyQt.QtWidgets import QApplication

    app = QApplication(sys.argv)
    widget = OWStackedLearner()
    dataset_name = sys.argv[1] if len(sys.argv) > 1 else 'iris'
    dataset = Table(dataset_name)
    widget.set_data(dataset)
    widget.show()
    app.exec_()
    widget.saveSettings()

0 commit comments

Comments
 (0)