Skip to content

Commit d1daaf8

Browse files
committed
OWCorrelations: Move from prototypes to core
1 parent 024bd7f commit d1daaf8

File tree

6 files changed

+510
-0
lines changed

6 files changed

+510
-0
lines changed
Lines changed: 44 additions & 0 deletions
Loading
Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
"""
2+
Correlations widget
3+
"""
4+
from enum import IntEnum
5+
from operator import attrgetter
6+
from itertools import combinations, groupby, chain
7+
8+
import numpy as np
9+
from scipy.stats import spearmanr, pearsonr
10+
from sklearn.cluster import KMeans
11+
12+
from AnyQt.QtCore import Qt, QItemSelectionModel, QItemSelection, QSize
13+
from AnyQt.QtGui import QStandardItem, QColor
14+
from AnyQt.QtWidgets import QApplication
15+
16+
from Orange.data import Table, Domain, ContinuousVariable, StringVariable
17+
from Orange.preprocess import SklImpute, Normalize
18+
from Orange.widgets import gui
19+
from Orange.widgets.settings import Setting, ContextSetting, \
20+
DomainContextHandler
21+
from Orange.widgets.utils.signals import Input, Output
22+
from Orange.widgets.visualize.utils import VizRankDialogAttrPair
23+
from Orange.widgets.widget import OWWidget, AttributeList, Msg
24+
25+
NAN = 2
26+
SIZE_LIMIT = 1000000
27+
28+
29+
class CorrelationType(IntEnum):
30+
"""
31+
Correlation type enumerator. Possible correlations: Pearson, Spearman.
32+
"""
33+
PEARSON, SPEARMAN = 0, 1
34+
35+
@staticmethod
36+
def items():
37+
"""
38+
Texts for correlation types. Can be used in gui controls (eg. combobox).
39+
"""
40+
return ["Pearson correlation", "Spearman correlation"]
41+
42+
43+
class KMeansCorrelationHeuristic:
44+
"""
45+
Heuristic to obtain the most promising attribute pairs, when there are to
46+
many attributes to calculate correlations for all possible pairs.
47+
"""
48+
n_clusters = 10
49+
50+
def __init__(self, data):
51+
self.n_attributes = len(data.domain.attributes)
52+
self.data = data
53+
self.states = None
54+
55+
def get_clusters_of_attributes(self):
56+
"""
57+
Generates groupes of attribute IDs, grouped by cluster. Clusters are
58+
obtained by KMeans algorithm.
59+
60+
:return: generator of attributes grouped by cluster
61+
"""
62+
data = Normalize()(self.data).X.T
63+
kmeans = KMeans(n_clusters=self.n_clusters, random_state=0).fit(data)
64+
labels_attrs = sorted([(l, i) for i, l in enumerate(kmeans.labels_)])
65+
for _, group in groupby(labels_attrs, key=lambda x: x[0]):
66+
group = list(group)
67+
if len(group) > 1:
68+
yield list(pair[1] for pair in group)
69+
70+
def get_states(self, initial_state):
71+
"""
72+
Generates the most promising states (attribute pairs).
73+
74+
:param initial_state: initial state; None if this is the first call
75+
:return: generator of tuples of states
76+
"""
77+
if self.states is not None:
78+
return chain([initial_state], self.states)
79+
self.states = chain.from_iterable(combinations(inds, 2) for inds in
80+
self.get_clusters_of_attributes())
81+
return self.states
82+
83+
84+
class CorrelationRank(VizRankDialogAttrPair):
85+
"""
86+
Correlations rank widget.
87+
"""
88+
NEGATIVE_COLOR = QColor(70, 190, 250)
89+
POSITIVE_COLOR = QColor(170, 242, 43)
90+
91+
def __init__(self, *args):
92+
super().__init__(*args)
93+
self.heuristic = None
94+
self.use_heuristic = False
95+
96+
def initialize(self):
97+
super().initialize()
98+
data = self.master.cont_data
99+
self.attrs = data and data.domain.attributes
100+
self.model_proxy.setFilterKeyColumn(-1)
101+
self.rank_table.horizontalHeader().setStretchLastSection(False)
102+
self.heuristic = None
103+
self.use_heuristic = False
104+
if data:
105+
# use heuristic if data is too big
106+
n_attrs = len(self.attrs)
107+
use_heuristic = n_attrs > KMeansCorrelationHeuristic.n_clusters
108+
self.use_heuristic = use_heuristic and \
109+
len(data) * n_attrs ** 2 > SIZE_LIMIT
110+
if self.use_heuristic:
111+
self.heuristic = KMeansCorrelationHeuristic(data)
112+
113+
def compute_score(self, state):
114+
(attr1, attr2), corr_type = state, self.master.correlation_type
115+
data = self.master.cont_data.X
116+
corr = pearsonr if corr_type == CorrelationType.PEARSON else spearmanr
117+
result = corr(data[:, attr1], data[:, attr2])[0]
118+
return -abs(result) if not np.isnan(result) else NAN, result
119+
120+
def row_for_state(self, score, state):
121+
attrs = sorted((self.attrs[x] for x in state), key=attrgetter("name"))
122+
attrs_item = QStandardItem(
123+
"{}, {}".format(attrs[0].name, attrs[1].name))
124+
attrs_item.setData(attrs, self._AttrRole)
125+
attrs_item.setData(Qt.AlignLeft + Qt.AlignTop, Qt.TextAlignmentRole)
126+
correlation_item = QStandardItem("{:+.3f}".format(score[1]))
127+
correlation_item.setData(attrs, self._AttrRole)
128+
correlation_item.setData(
129+
self.NEGATIVE_COLOR if score[1] < 0 else self.POSITIVE_COLOR,
130+
gui.TableBarItem.BarColorRole)
131+
return [correlation_item, attrs_item]
132+
133+
def check_preconditions(self):
134+
return self.master.cont_data is not None
135+
136+
def iterate_states(self, initial_state):
137+
if self.use_heuristic:
138+
return self.heuristic.get_states(initial_state)
139+
else:
140+
return super().iterate_states(initial_state)
141+
142+
def state_count(self):
143+
if self.use_heuristic:
144+
n_clusters = KMeansCorrelationHeuristic.n_clusters
145+
n_avg_attrs = len(self.attrs) / n_clusters
146+
return n_clusters * n_avg_attrs * (n_avg_attrs - 1) / 2
147+
else:
148+
n_attrs = len(self.attrs)
149+
return n_attrs * (n_attrs - 1) / 2
150+
151+
@staticmethod
152+
def bar_length(score):
153+
return abs(score[1])
154+
155+
156+
class OWCorrelations(OWWidget):
157+
name = "Correlations"
158+
description = "Compute all pairwise attribute correlations."
159+
icon = "icons/Correlations.svg"
160+
priority = 1106
161+
162+
class Inputs:
163+
data = Input("Data", Table)
164+
165+
class Outputs:
166+
data = Output("Data", Table)
167+
features = Output("Features", AttributeList)
168+
correlations = Output("Correlations", Table)
169+
170+
want_control_area = False
171+
172+
settingsHandler = DomainContextHandler()
173+
selection = ContextSetting(())
174+
correlation_type = Setting(0)
175+
176+
class Information(OWWidget.Information):
177+
not_enough_vars = Msg("Need at least two continuous features.")
178+
not_enough_inst = Msg("Need at least two instances.")
179+
180+
def __init__(self):
181+
super().__init__()
182+
self.data = None
183+
self.cont_data = None
184+
185+
# GUI
186+
box = gui.vBox(self.mainArea)
187+
self.correlation_combo = gui.comboBox(
188+
box, self, "correlation_type", items=CorrelationType.items(),
189+
orientation=Qt.Horizontal, callback=self._correlation_combo_changed)
190+
191+
self.vizrank, _ = CorrelationRank.add_vizrank(
192+
None, self, None, self._vizrank_selection_changed)
193+
self.vizrank.progressBar = self.progressBar
194+
195+
gui.separator(box)
196+
box.layout().addWidget(self.vizrank.filter)
197+
box.layout().addWidget(self.vizrank.rank_table)
198+
199+
button_box = gui.hBox(self.mainArea)
200+
button_box.layout().addWidget(self.vizrank.button)
201+
202+
def sizeHint(self):
203+
return QSize(350, 400)
204+
205+
def _correlation_combo_changed(self):
206+
self.apply()
207+
208+
def _vizrank_selection_changed(self, *args):
209+
self.selection = args
210+
self.commit()
211+
212+
def _vizrank_select(self):
213+
model = self.vizrank.rank_table.model()
214+
selection = QItemSelection()
215+
names = sorted(x.name for x in self.selection)
216+
for i in range(model.rowCount()):
217+
if sorted(x.name for x in model.data(model.index(i, 0),
218+
CorrelationRank._AttrRole)) \
219+
== names:
220+
selection.select(model.index(i, 0), model.index(i, 1))
221+
self.vizrank.rank_table.selectionModel().select(
222+
selection, QItemSelectionModel.ClearAndSelect)
223+
break
224+
225+
@Inputs.data
226+
def set_data(self, data):
227+
self.closeContext()
228+
self.clear_messages()
229+
self.data = data
230+
self.cont_data = None
231+
self.selection = ()
232+
if data is not None:
233+
cont_attrs = [a for a in data.domain.attributes if a.is_continuous]
234+
if len(cont_attrs) < 2:
235+
self.Information.not_enough_vars()
236+
elif len(data) < 2:
237+
self.Information.not_enough_inst()
238+
else:
239+
domain = data.domain
240+
cont_dom = Domain(cont_attrs, domain.class_vars, domain.metas)
241+
self.cont_data = SklImpute()(Table.from_table(cont_dom, data))
242+
self.apply()
243+
self.openContext(self.data)
244+
self._vizrank_select()
245+
246+
def apply(self):
247+
self.vizrank.initialize()
248+
if self.cont_data is not None:
249+
# this triggers self.commit() by changing vizrank selection
250+
self.vizrank.toggle()
251+
header = self.vizrank.rank_table.horizontalHeader()
252+
header.setStretchLastSection(True)
253+
else:
254+
self.commit()
255+
256+
def commit(self):
257+
if self.data is None or self.cont_data is None:
258+
self.Outputs.data.send(self.data)
259+
self.Outputs.features.send(None)
260+
self.Outputs.correlations.send(None)
261+
return
262+
263+
metas = [StringVariable("Feature 1"), StringVariable("Feature 2")]
264+
domain = Domain([ContinuousVariable("Correlation")], metas=metas)
265+
model = self.vizrank.rank_model
266+
x = np.array([[float(model.data(model.index(row, 0)))] for row
267+
in range(model.rowCount())])
268+
m = np.array([[attr.name
269+
for attr in model.data(model.index(row, 0),
270+
CorrelationRank._AttrRole)]
271+
for row in range(model.rowCount())], dtype=object)
272+
corr_table = Table(domain, x, metas=m)
273+
corr_table.name = "Correlations"
274+
275+
self.Outputs.data.send(self.data)
276+
# data has been imputed; send original attributes
277+
self.Outputs.features.send(AttributeList([attr.compute_value.variable
278+
for attr in self.selection]))
279+
self.Outputs.correlations.send(corr_table)
280+
281+
def send_report(self):
282+
self.report_table(CorrelationType.items()[self.correlation_type],
283+
self.vizrank.rank_table)
284+
285+
286+
if __name__ == "__main__":
287+
app = QApplication([])
288+
ow = OWCorrelations()
289+
iris = Table("iris")
290+
ow.set_data(iris)
291+
ow.show()
292+
app.exec_()
293+
ow.saveSettings()

0 commit comments

Comments
 (0)