Skip to content

Commit eb39442

Browse files
committed
OWCorrelations: Move from prototypes to core
1 parent 024bd7f commit eb39442

File tree

6 files changed

+511
-0
lines changed

6 files changed

+511
-0
lines changed
Lines changed: 44 additions & 0 deletions
Loading
Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
"""
2+
Correlations widget
3+
"""
4+
from enum import IntEnum
5+
from operator import attrgetter
6+
from itertools import combinations, groupby, chain
7+
8+
import numpy as np
9+
from scipy.stats import spearmanr, pearsonr
10+
from sklearn.cluster import KMeans
11+
12+
from AnyQt.QtCore import Qt, QItemSelectionModel, QItemSelection, QSize
13+
from AnyQt.QtGui import QStandardItem, QColor
14+
15+
from Orange.data import Table, Domain, ContinuousVariable, StringVariable
16+
from Orange.preprocess import SklImpute, Normalize
17+
from Orange.widgets import gui
18+
from Orange.widgets.settings import Setting, ContextSetting, \
19+
DomainContextHandler
20+
from Orange.widgets.utils.signals import Input, Output
21+
from Orange.widgets.visualize.utils import VizRankDialogAttrPair
22+
from Orange.widgets.widget import OWWidget, AttributeList, Msg
23+
24+
# Score that CorrelationRank.compute_score reports in place of -abs(r) when
# the correlation coefficient is NaN. Regular scores lie in [-1, 0].
NAN = 2
# Threshold for len(data) * n_attrs ** 2; above it (and with more attributes
# than clusters) CorrelationRank.initialize enables the k-means heuristic.
SIZE_LIMIT = 10 ** 6
26+
27+
28+
class CorrelationType(IntEnum):
    """
    Kinds of correlation that can be computed: Pearson or Spearman.
    """
    PEARSON = 0
    SPEARMAN = 1

    @staticmethod
    def items():
        """
        Human-readable labels of the correlation types, listed in the order
        of the member values, e.g. for filling a combo box.

        :return: list of label strings
        """
        labels = ["Pearson correlation", "Spearman correlation"]
        return labels
40+
41+
42+
class KMeansCorrelationHeuristic:
    """
    Heuristic for picking the most promising attribute pairs when there are
    too many attributes to compute correlations for all possible pairs.

    Attributes are clustered with k-means and only pairs of attributes that
    fall into the same cluster are proposed.
    """
    n_clusters = 10

    def __init__(self, data):
        self.n_attributes = len(data.domain.attributes)
        self.data = data
        # Lazily created iterator over attribute-index pairs; see get_states.
        self.states = None

    def get_clusters_of_attributes(self):
        """
        Generate groups of attribute indices, one group per k-means cluster.
        Clusters that contain a single attribute are skipped.

        :return: generator of lists of attribute indices
        """
        # Transpose so that every attribute (column) becomes one sample.
        columns = Normalize()(self.data).X.T
        model = KMeans(n_clusters=self.n_clusters, random_state=0)
        model.fit(columns)
        # groupby needs its input ordered by the grouping key (the label).
        by_label = sorted(
            (label, index) for index, label in enumerate(model.labels_))
        for _, members in groupby(by_label, key=lambda pair: pair[0]):
            indices = [index for _, index in members]
            if len(indices) > 1:
                yield indices

    def get_states(self, initial_state):
        """
        Generate the most promising states (attribute pairs).

        The first call creates and returns the iterator of pairs. Later calls
        return `initial_state` followed by whatever is left in that iterator.

        :param initial_state: initial state; None if this is the first call
        :return: iterator of tuples of states
        """
        if self.states is None:
            clusters = self.get_clusters_of_attributes()
            self.states = chain.from_iterable(
                combinations(cluster, 2) for cluster in clusters)
            return self.states
        return chain([initial_state], self.states)
81+
82+
83+
class CorrelationRank(VizRankDialogAttrPair):
    """
    Rank dialog that scores pairs of attributes by their correlation.
    """
    NEGATIVE_COLOR = QColor(70, 190, 250)
    POSITIVE_COLOR = QColor(170, 242, 43)

    def __init__(self, *args):
        super().__init__(*args)
        self.heuristic = None
        self.use_heuristic = False

    def initialize(self):
        """
        Reset the dialog for the master widget's current continuous data and
        decide whether the k-means heuristic should be used.
        """
        super().initialize()
        data = self.master.cont_data
        self.attrs = data.domain.attributes if data else data
        self.model_proxy.setFilterKeyColumn(-1)
        header = self.rank_table.horizontalHeader()
        header.setStretchLastSection(False)
        self.heuristic = None
        self.use_heuristic = False
        if not data:
            return

        # use heuristic if data is too big
        n_attrs = len(self.attrs)
        enough_attrs = n_attrs > KMeansCorrelationHeuristic.n_clusters
        self.use_heuristic = \
            enough_attrs and len(data) * n_attrs ** 2 > SIZE_LIMIT
        if self.use_heuristic:
            self.heuristic = KMeansCorrelationHeuristic(data)

    def compute_score(self, state):
        """
        Compute the correlation between the two columns indexed by `state`.

        :param state: pair of column indices into the master's `cont_data.X`
        :return: tuple (score, coefficient); score is -abs(coefficient), or
            NAN when the coefficient is NaN
        """
        attr1, attr2 = state
        columns = self.master.cont_data.X
        if self.master.correlation_type == CorrelationType.PEARSON:
            corr = pearsonr
        else:
            corr = spearmanr
        result = corr(columns[:, attr1], columns[:, attr2])[0]
        if np.isnan(result):
            return NAN, result
        return -abs(result), result

    def row_for_state(self, score, state):
        """
        Build the model row for a scored state.

        :param score: tuple as returned by `compute_score`
        :param state: pair of attribute indices
        :return: list [correlation item, attribute names item]
        """
        pair = [self.attrs[index] for index in state]
        pair.sort(key=attrgetter("name"))
        coefficient = score[1]

        attrs_item = QStandardItem(
            "{}, {}".format(pair[0].name, pair[1].name))
        attrs_item.setData(pair, self._AttrRole)
        attrs_item.setData(Qt.AlignLeft + Qt.AlignTop, Qt.TextAlignmentRole)

        correlation_item = QStandardItem("{:+.3f}".format(coefficient))
        correlation_item.setData(pair, self._AttrRole)
        if coefficient < 0:
            bar_color = self.NEGATIVE_COLOR
        else:
            bar_color = self.POSITIVE_COLOR
        correlation_item.setData(bar_color, gui.TableBarItem.BarColorRole)
        return [correlation_item, attrs_item]

    def check_preconditions(self):
        """Ranking is possible only when the master has continuous data."""
        return self.master.cont_data is not None

    def iterate_states(self, initial_state):
        """
        Yield states from the heuristic when it is enabled, otherwise from
        the base class.
        """
        if self.use_heuristic:
            return self.heuristic.get_states(initial_state)
        return super().iterate_states(initial_state)

    def state_count(self):
        """
        Number of states to evaluate. With the heuristic this is an estimate
        that assumes attributes are spread evenly over the clusters.
        """
        n_attrs = len(self.attrs)
        if not self.use_heuristic:
            return n_attrs * (n_attrs - 1) / 2
        n_clusters = KMeansCorrelationHeuristic.n_clusters
        n_avg_attrs = n_attrs / n_clusters
        return n_clusters * n_avg_attrs * (n_avg_attrs - 1) / 2

    @staticmethod
    def bar_length(score):
        """Bar length is the absolute value of the coefficient."""
        coefficient = score[1]
        return abs(coefficient)
153+
154+
155+
class OWCorrelations(OWWidget):
    """
    Widget that ranks pairs of continuous attributes by their correlation
    and outputs the selected pair and the table of correlations.
    """
    name = "Correlations"
    description = "Compute all pairwise attribute correlations."
    icon = "icons/Correlations.svg"
    priority = 1106

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        data = Output("Data", Table)
        features = Output("Features", AttributeList)
        correlations = Output("Correlations", Table)

    want_control_area = False

    settingsHandler = DomainContextHandler()
    selection = ContextSetting(())
    correlation_type = Setting(0)

    class Information(OWWidget.Information):
        not_enough_vars = Msg("Need at least two continuous features.")
        not_enough_inst = Msg("Need at least two instances.")

    def __init__(self):
        super().__init__()
        self.data = None
        # Input restricted to continuous attributes and imputed (set_data).
        self.cont_data = None

        # GUI
        box = gui.vBox(self.mainArea)
        self.correlation_combo = gui.comboBox(
            box, self, "correlation_type",
            items=CorrelationType.items(),
            orientation=Qt.Horizontal,
            callback=self._correlation_combo_changed)

        self.vizrank, _ = CorrelationRank.add_vizrank(
            None, self, None, self._vizrank_selection_changed)
        self.vizrank.progressBar = self.progressBar

        gui.separator(box)
        for part in (self.vizrank.filter, self.vizrank.rank_table):
            box.layout().addWidget(part)

        button_box = gui.hBox(self.mainArea)
        button_box.layout().addWidget(self.vizrank.button)

    def sizeHint(self):
        return QSize(350, 400)

    def _correlation_combo_changed(self):
        # A different correlation type requires re-ranking.
        self.apply()

    def _vizrank_selection_changed(self, *args):
        # args are the attributes of the row selected in the rank table.
        self.selection = args
        self.commit()

    def _vizrank_select(self):
        """
        Select the first rank table row whose attribute names match the
        names of the attributes in `self.selection`.
        """
        table = self.vizrank.rank_table
        model = table.model()
        wanted = sorted(attr.name for attr in self.selection)
        for row in range(model.rowCount()):
            first_cell = model.index(row, 0)
            row_attrs = model.data(first_cell, CorrelationRank._AttrRole)
            if sorted(attr.name for attr in row_attrs) != wanted:
                continue
            to_select = QItemSelection()
            to_select.select(first_cell, model.index(row, 1))
            table.selectionModel().select(
                to_select, QItemSelectionModel.ClearAndSelect)
            break

    @Inputs.data
    def set_data(self, data):
        """
        Store the input, prepare the imputed continuous-only table (when
        there are at least two continuous attributes and two instances) and
        recompute the ranking.
        """
        self.closeContext()
        self.clear_messages()
        self.data = data
        self.cont_data = None
        self.selection = ()
        if data is not None:
            domain = data.domain
            cont_attrs = [var for var in domain.attributes
                          if var.is_continuous]
            if len(cont_attrs) < 2:
                self.Information.not_enough_vars()
            elif len(data) < 2:
                self.Information.not_enough_inst()
            else:
                cont_dom = Domain(cont_attrs, domain.class_vars, domain.metas)
                cont_table = Table.from_table(cont_dom, data)
                self.cont_data = SklImpute()(cont_table)
        self.apply()
        self.openContext(self.data)
        self._vizrank_select()

    def apply(self):
        """
        Re-initialize the ranking; start it when continuous data is
        available, otherwise just send the outputs.
        """
        self.vizrank.initialize()
        if self.cont_data is None:
            self.commit()
            return
        # this triggers self.commit() by changing vizrank selection
        self.vizrank.toggle()
        header = self.vizrank.rank_table.horizontalHeader()
        header.setStretchLastSection(True)

    def commit(self):
        """
        Send the input data, the selected features and a table with one row
        per ranked pair (correlation value plus the two feature names).
        """
        if self.data is None or self.cont_data is None:
            self.Outputs.data.send(self.data)
            self.Outputs.features.send(None)
            self.Outputs.correlations.send(None)
            return

        metas = [StringVariable("Feature 1"), StringVariable("Feature 2")]
        domain = Domain([ContinuousVariable("Correlation")], metas=metas)
        model = self.vizrank.rank_model
        rows = range(model.rowCount())
        # The correlation is read back from the text shown in column 0.
        x = np.array(
            [[float(model.data(model.index(row, 0)))] for row in rows])
        attr_role = CorrelationRank._AttrRole
        m = np.array(
            [[attr.name for attr in model.data(model.index(row, 0), attr_role)]
             for row in rows],
            dtype=object)
        corr_table = Table(domain, x, metas=m)
        corr_table.name = "Correlations"

        self.Outputs.data.send(self.data)
        # data has been imputed; send original attributes
        originals = [attr.compute_value.variable for attr in self.selection]
        self.Outputs.features.send(AttributeList(originals))
        self.Outputs.correlations.send(corr_table)

    def send_report(self):
        title = CorrelationType.items()[self.correlation_type]
        self.report_table(title, self.vizrank.rank_table)
283+
284+
285+
if __name__ == "__main__":
    # Manual run: show the widget on the iris data set and store its
    # settings once the event loop exits.
    from AnyQt.QtWidgets import QApplication

    application = QApplication([])
    widget = OWCorrelations()
    widget.set_data(Table("iris"))
    widget.show()
    application.exec_()
    widget.saveSettings()

0 commit comments

Comments
 (0)