forked from biolab/orange3
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path owlearnerwidget.py
More file actions
237 lines (186 loc) · 7.7 KB
/
owlearnerwidget.py
File metadata and controls
237 lines (186 loc) · 7.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import numpy as np
from PyQt4.QtCore import QTimer, Qt
from Orange.classification.base_classification import LearnerClassification
from Orange.data import Table
from Orange.preprocess.preprocess import Preprocess
from Orange.widgets import gui
from Orange.widgets.settings import Setting
from Orange.widgets.utils.sql import check_sql_input
from Orange.widgets.widget import OWWidget, WidgetMetaClass
class DefaultWidgetChannelsMetaClass(WidgetMetaClass):
    """Metaclass that merges default input and output channels into widgets.

    Subclasses customise it through these hooks:

    * ``default_inputs`` / ``default_outputs`` -- channel entries every
      concrete widget receives;
    * ``add_extra_attributes`` -- extra class attributes to inject;
    * ``REQUIRED_ATTRIBUTES`` -- names a concrete class body must define.

    A class is processed only if its own body defines a truthy ``name``;
    classes without one (abstract bases) are created unchanged.
    """

    # Attribute names that must be present in the body of a concrete class.
    # Only the class's own namespace is inspected, not its bases.
    REQUIRED_ATTRIBUTES = []

    def __new__(mcls, name, bases, attrib):
        # Abstract base classes carry no ``name`` and are left untouched.
        if attrib.get('name', False):
            if any(required not in attrib
                   for required in mcls.REQUIRED_ATTRIBUTES):
                listed = "', '".join(mcls.REQUIRED_ATTRIBUTES)
                raise AttributeError(
                    "'{name}' must have '{attrs}' attributes".format(
                        name=name, attrs=listed))

            # Defaults are built first; channels declared on the class are
            # appended only when their name is not already a default.
            defaults_out = mcls.default_outputs(attrib)
            declared_out = attrib.get('outputs', [])
            attrib['outputs'] = mcls.update_channel(defaults_out, declared_out)

            defaults_in = mcls.default_inputs(attrib)
            declared_in = attrib.get('inputs', [])
            attrib['inputs'] = mcls.update_channel(defaults_in, declared_in)

            mcls.add_extra_attributes(name, attrib)
        return super().__new__(mcls, name, bases, attrib)

    @classmethod
    def default_inputs(cls, attrib):
        """Return the input channels every concrete widget gets (none here)."""
        return []

    @classmethod
    def default_outputs(cls, attrib):
        """Return the output channels every concrete widget gets (none here)."""
        return []

    @classmethod
    def update_channel(cls, channel, items):
        """Append to *channel* the entries of *items* whose name is new.

        An entry's name is its element ``[0]``. Names are compared against
        *channel* as it was on entry, so existing (default) entries take
        precedence over declared ones. *channel* is modified in place and
        also returned.
        """
        default_names = {entry[0] for entry in channel}
        channel.extend(entry for entry in items
                       if entry[0] not in default_names)
        return channel

    @classmethod
    def add_extra_attributes(cls, name, attrib):
        """Hook for injecting extra class attributes; a no-op by default."""
        return attrib
class OWBaseLearnerMeta(DefaultWidgetChannelsMetaClass):
    """Metaclass that equips learner widgets with their standard channels.

    Inputs: ``Data`` (``Table``, handled by ``set_data``) and
    ``Preprocessor`` (``Preprocess``, handled by ``set_preprocessor``).
    Outputs: ``Learner`` (the ``LEARNER`` class) and a model channel named
    ``Classifier`` or ``Predictor`` depending on the learner type.
    """

    REQUIRED_ATTRIBUTES = ['LEARNER']

    @classmethod
    def default_inputs(cls, attrib):
        """Return the data and preprocessor input channels."""
        data_channel = ("Data", Table, "set_data")
        preprocessor_channel = ("Preprocessor", Preprocess, "set_preprocessor")
        return [data_channel, preprocessor_channel]

    @classmethod
    def default_outputs(cls, attrib):
        """Return the learner and model output channels.

        Side effect: stores the chosen model channel name in
        ``attrib['OUTPUT_MODEL_NAME']`` so the widget knows where to send
        trained models.
        """
        learner = attrib['LEARNER']
        if issubclass(learner, LearnerClassification):
            channel_name = 'Classifier'
        else:
            channel_name = 'Predictor'
        attrib['OUTPUT_MODEL_NAME'] = channel_name
        learner_output = ("Learner", learner)
        model_output = (channel_name, learner.__returns__)
        return [learner_output, model_output]

    @classmethod
    def add_extra_attributes(cls, name, attrib):
        """Give the widget a ``learner_name`` setting unless it defines one.

        The setting defaults to the widget's ``name``.
        """
        if 'learner_name' in attrib:
            return attrib
        attrib['learner_name'] = Setting(attrib['name'])
        return attrib
class OWBaseLearner(OWWidget, metaclass=OWBaseLearnerMeta):
    """Abstract widget for classification/regression learners.

    Notes:
        All learner widgets should define learner class LEARNER.
        LEARNER should have __returns__ attribute.
        Overwrite `create_learner`, `add_main_layout` and
        `get_learner_parameters` in case LEARNER has extra parameters.
    """
    LEARNER = None

    want_main_area = False
    resizing_enabled = False
    auto_apply = Setting(True)

    # Message ids passed to self.error / self.warning.
    DATA_ERROR_ID = 1
    OUTDATED_LEARNER_WARNING_ID = 2

    def __init__(self):
        super().__init__()
        self.data = None
        self.valid_data = False
        self.learner = None
        self.model = None
        self.preprocessors = None
        self.outdated_settings = False
        self.setup_layout()
        # NOTE(review): gui.auto_commit (called from add_bottom_buttons) is
        # expected to replace `self.apply` with a wrapper that only marks the
        # widget dirty while auto-apply is off, keeping the raw method as
        # `unconditional_apply` -- confirm against Orange.widgets.gui.
        # The initial learner must be built even with auto-apply off;
        # otherwise self.learner stays None and check_data() rejects all data
        # until Apply is pressed. Fall back to self.apply if absent.
        QTimer.singleShot(0, getattr(self, "unconditional_apply", self.apply))

    def create_learner(self):
        """Creates a learner with current configuration.

        Returns:
            Learner: an instance of Orange.base.learner subclass.
        """
        return self.LEARNER(preprocessors=self.preprocessors)

    def get_learner_parameters(self):
        """Creates an `OrderedDict` or a sequence of pairs with current model
        configuration.

        Returns:
            OrderedDict or List: (option, value) pairs or dict
        """
        return []

    def set_preprocessor(self, preprocessor):
        """Add user-set preprocessors before the default, mandatory ones.

        Args:
            preprocessor: object received on the Preprocessor input, or None
                when the input is cleared (then only LEARNER's own
                preprocessors are used).
        """
        user_supplied = (preprocessor,) if preprocessor is not None else ()
        self.preprocessors = user_supplied + tuple(self.LEARNER.preprocessors)
        self.apply()

    @check_sql_input
    def set_data(self, data):
        """Set the input train data set.

        Data without a target variable is rejected with an error and
        treated as missing.
        """
        self.error(self.DATA_ERROR_ID)
        self.data = data
        if data is not None and data.domain.class_var is None:
            self.error(self.DATA_ERROR_ID, "Data has no target variable")
            self.data = None
        self.update_model()

    def apply(self):
        """Applies learner and sends new model."""
        self.update_learner()
        self.update_model()

    def update_learner(self):
        """Rebuild the learner from current settings, name it and send it.

        Also clears the "outdated settings" state and its warning.
        """
        self.learner = self.create_learner()
        self.learner.name = self.learner_name
        self.send("Learner", self.learner)
        self.outdated_settings = False
        self.warning(self.OUTDATED_LEARNER_WARNING_ID)

    def update_model(self):
        """Train a model on the current data and send it.

        When data or learner are not usable, None is sent instead, so a
        model trained on previous data is never re-sent after the input
        becomes invalid.
        """
        if self.check_data():
            self.model = self.learner(self.data)
            self.model.name = self.learner_name
            self.model.instances = self.data
            self.valid_data = True
        else:
            self.model = None
        self.send(self.OUTPUT_MODEL_NAME, self.model)

    def check_data(self):
        """Validate self.data against self.learner and report problems.

        Returns:
            bool: True when both data and learner are present and the data
            is adequate for the learner, has at least two distinct target
            values and at least one feature column.
        """
        self.valid_data = False
        if self.data is not None and self.learner is not None:
            self.error(self.DATA_ERROR_ID)
            if not self.learner.check_learner_adequacy(self.data.domain):
                self.error(self.DATA_ERROR_ID, self.learner.learner_adequacy_err_msg)
            elif len(np.unique(self.data.Y)) < 2:
                self.error(self.DATA_ERROR_ID,
                           "Data contains a single target value. "
                           "There is nothing to learn.")
            elif self.data.X.size == 0:
                self.error(self.DATA_ERROR_ID,
                           "Data has no features to learn from.")
            else:
                self.valid_data = True
        return self.valid_data

    def settings_changed(self, *args, **kwargs):
        """Mark the learner as outdated and request a re-apply.

        With auto-apply off, the user is warned that changes are pending.
        """
        self.outdated_settings = True
        self.warning(self.OUTDATED_LEARNER_WARNING_ID,
                     None if self.auto_apply else "Press Apply to submit changes.")
        # Looked up at call time, so this goes through the auto-commit
        # wrapper installed in add_bottom_buttons (if any).
        self.apply()

    def send_report(self):
        """Report the learner name, its parameters and the training data."""
        self.report_items((("Name", self.learner_name),))

        model_parameters = self.get_learner_parameters()
        if model_parameters:
            self.report_items("Model parameters", model_parameters)

        if self.data:
            self.report_data("Data", self.data)

    # GUI
    def setup_layout(self):
        """Build the control area: name field, learner controls, buttons."""
        self.add_learner_name_widget()
        self.add_main_layout()
        self.add_bottom_buttons()

    def add_main_layout(self):
        """Creates layout with the learner configuration widgets.

        Override this method for laying out any learner-specific parameter controls.
        See setup_layout() method for execution order.
        """
        pass

    def add_learner_name_widget(self):
        """Add the line edit bound to the `learner_name` setting."""
        # The lambda defers the `self.apply` lookup: this runs before
        # add_bottom_buttons, and binding self.apply here directly would
        # bypass the wrapper that gui.auto_commit installs later.
        gui.lineEdit(self.controlArea, self, 'learner_name', box='Name',
                     tooltip='The name will identify this model in other widgets',
                     orientation=Qt.Horizontal,
                     callback=lambda: self.apply())

    def add_bottom_buttons(self):
        """Add the report button and the Apply / auto-apply control."""
        box = gui.hBox(self.controlArea, True)
        box.layout().addWidget(self.report_button)
        gui.separator(box, 15)
        self.apply_button = gui.auto_commit(box, self, 'auto_apply', '&Apply',
                                            box=False, commit=self.apply)