|
| 1 | +""" |
| 2 | +Correlations widget |
| 3 | +""" |
| 4 | +from enum import IntEnum |
| 5 | +from operator import attrgetter |
| 6 | +from itertools import combinations, groupby, chain |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +from scipy.stats import spearmanr, pearsonr |
| 10 | +from sklearn.cluster import KMeans |
| 11 | + |
| 12 | +from AnyQt.QtCore import Qt, QItemSelectionModel, QItemSelection, QSize |
| 13 | +from AnyQt.QtGui import QStandardItem, QColor |
| 14 | +from AnyQt.QtWidgets import QApplication |
| 15 | + |
| 16 | +from Orange.data import Table, Domain, ContinuousVariable, StringVariable |
| 17 | +from Orange.preprocess import SklImpute, Normalize |
| 18 | +from Orange.widgets import gui |
| 19 | +from Orange.widgets.settings import Setting, ContextSetting, \ |
| 20 | + DomainContextHandler |
| 21 | +from Orange.widgets.utils.signals import Input, Output |
| 22 | +from Orange.widgets.visualize.utils import VizRankDialogAttrPair |
| 23 | +from Orange.widgets.widget import OWWidget, AttributeList, Msg |
| 24 | + |
# Rank score assigned to attribute pairs whose correlation is NaN. Regular
# scores are -abs(correlation), so this value compares greater than all of
# them.
NAN = 2
# Upper bound for len(data) * n_attributes ** 2; above it the k-means
# heuristic is used instead of scoring every attribute pair.
SIZE_LIMIT = 10 ** 6
| 27 | + |
| 28 | + |
class CorrelationType(IntEnum):
    """
    Enumeration of the supported correlation measures.

    Members are integers: ``PEARSON`` is 0 and ``SPEARMAN`` is 1.
    """
    PEARSON = 0
    SPEARMAN = 1

    @staticmethod
    def items():
        """
        Return the display texts of the correlation types.

        The list has one text per member, in the order of the member values,
        so it can be used to fill gui controls (e.g. a combobox).

        :return: a new list of strings
        """
        labels = ("Pearson correlation", "Spearman correlation")
        return list(labels)
| 41 | + |
| 42 | + |
class KMeansCorrelationHeuristic:
    """
    Heuristic to obtain the most promising attribute pairs, when there are
    too many attributes to calculate correlations for all possible pairs.

    Attributes are clustered with KMeans and only pairs of attributes that
    belong to the same cluster are proposed.
    """
    n_clusters = 10

    def __init__(self, data):
        """
        :param data: table whose ``domain.attributes`` are to be paired
        """
        self.data = data
        self.n_attributes = len(data.domain.attributes)
        # iterator over attribute index pairs; created by the first call of
        # get_states and shared by all later calls
        self.states = None

    def get_clusters_of_attributes(self):
        """
        Generate groups of attribute indices, grouped by cluster. Clusters
        are obtained by the KMeans algorithm, fitted on the transposed
        normalized data. Groups are yielded in ascending order of cluster
        label; clusters with a single attribute are skipped.

        :return: generator of lists of attribute indices, one per cluster
        """
        columns = Normalize()(self.data).X.T
        model = KMeans(n_clusters=self.n_clusters, random_state=0)
        model.fit(columns)
        members = {}
        for index, label in enumerate(model.labels_):
            members.setdefault(label, []).append(index)
        for label in sorted(members):
            indices = members[label]
            if len(indices) > 1:
                yield indices

    def get_states(self, initial_state):
        """
        Generate the most promising states (attribute pairs).

        The first call creates the iterator of all within-cluster pairs and
        returns it. Later calls return ``initial_state`` followed by whatever
        is left in that same iterator.

        :param initial_state: initial state; None if this is the first call
        :return: iterator of states (tuples of two attribute indices)
        """
        if self.states is None:
            clusters = self.get_clusters_of_attributes()
            self.states = chain.from_iterable(
                combinations(indices, 2) for indices in clusters)
            return self.states
        return chain([initial_state], self.states)
| 82 | + |
| 83 | + |
class CorrelationRank(VizRankDialogAttrPair):
    """
    Correlations rank widget.

    Scores pairs of attributes of the master widget's ``cont_data`` by
    Pearson or Spearman correlation, as selected by the master's
    ``correlation_type``.
    """
    NEGATIVE_COLOR = QColor(70, 190, 250)
    POSITIVE_COLOR = QColor(170, 242, 43)

    def __init__(self, *args):
        super().__init__(*args)
        self.use_heuristic = False
        self.heuristic = None

    def initialize(self):
        """
        Reset the rank for the master's current data and decide whether the
        k-means heuristic should be used instead of scoring all pairs.
        """
        super().initialize()
        data = self.master.cont_data
        self.attrs = data.domain.attributes if data else data
        self.model_proxy.setFilterKeyColumn(-1)
        self.rank_table.horizontalHeader().setStretchLastSection(False)
        self.heuristic = None
        self.use_heuristic = False
        if not data:
            return
        # use heuristic if data is too big: there must be more attributes
        # than clusters and the size estimate must exceed SIZE_LIMIT
        n_attrs = len(self.attrs)
        more_attrs_than_clusters = \
            n_attrs > KMeansCorrelationHeuristic.n_clusters
        self.use_heuristic = more_attrs_than_clusters and \
            len(data) * n_attrs ** 2 > SIZE_LIMIT
        if self.use_heuristic:
            self.heuristic = KMeansCorrelationHeuristic(data)

    def compute_score(self, state):
        """
        Compute the correlation of the two columns given by state.

        :param state: pair of column indices into ``master.cont_data.X``
        :return: tuple (rank score, correlation); the rank score is
            ``-abs(correlation)``, or ``NAN`` if the correlation is NaN
        """
        attr1, attr2 = state
        if self.master.correlation_type == CorrelationType.PEARSON:
            corr = pearsonr
        else:
            corr = spearmanr
        columns = self.master.cont_data.X
        result = corr(columns[:, attr1], columns[:, attr2])[0]
        rank_score = NAN if np.isnan(result) else -abs(result)
        return rank_score, result

    def row_for_state(self, score, state):
        """
        Build the model row for a scored state.

        :param score: tuple as returned by ``compute_score``
        :param state: pair of indices into ``self.attrs``
        :return: list [correlation item, attribute names item]
        """
        attrs = sorted([self.attrs[index] for index in state],
                       key=lambda attr: attr.name)
        correlation = score[1]

        first_name, second_name = attrs[0].name, attrs[1].name
        attrs_item = QStandardItem("{}, {}".format(first_name, second_name))
        attrs_item.setData(attrs, self._AttrRole)
        attrs_item.setData(Qt.AlignLeft + Qt.AlignTop, Qt.TextAlignmentRole)

        correlation_item = QStandardItem("{:+.3f}".format(correlation))
        correlation_item.setData(attrs, self._AttrRole)
        if correlation < 0:
            bar_color = self.NEGATIVE_COLOR
        else:
            bar_color = self.POSITIVE_COLOR
        correlation_item.setData(bar_color, gui.TableBarItem.BarColorRole)
        return [correlation_item, attrs_item]

    def check_preconditions(self):
        """Return True if the master holds data to compute on."""
        return self.master.cont_data is not None

    def iterate_states(self, initial_state):
        """
        Return an iterator of states: the heuristic's states if the
        heuristic is in use, otherwise those of the base class.
        """
        if not self.use_heuristic:
            return super().iterate_states(initial_state)
        return self.heuristic.get_states(initial_state)

    def state_count(self):
        """
        Return the number of states as a float.

        Without the heuristic this is the number of all attribute pairs.
        With it, the count is computed as if every cluster held the same
        (average) number of attributes.
        """
        n_attrs = len(self.attrs)
        if not self.use_heuristic:
            return n_attrs * (n_attrs - 1) / 2
        n_clusters = KMeansCorrelationHeuristic.n_clusters
        n_avg_attrs = n_attrs / n_clusters
        return n_clusters * n_avg_attrs * (n_avg_attrs - 1) / 2

    @staticmethod
    def bar_length(score):
        """Return the absolute correlation from a ``compute_score`` tuple."""
        correlation = score[1]
        return abs(correlation)
| 154 | + |
| 155 | + |
class OWCorrelations(OWWidget):
    """
    Widget that ranks pairs of continuous features by their correlation.

    Outputs the input data unchanged, the selected pair of features and a
    table with the correlations listed in the rank.
    """
    name = "Correlations"
    description = "Compute all pairwise attribute correlations."
    icon = "icons/Correlations.svg"
    priority = 1106

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        data = Output("Data", Table)
        features = Output("Features", AttributeList)
        correlations = Output("Correlations", Table)

    want_control_area = False

    settingsHandler = DomainContextHandler()
    selection = ContextSetting(())
    correlation_type = Setting(0)

    class Information(OWWidget.Information):
        not_enough_vars = Msg("Need at least two continuous features.")
        not_enough_inst = Msg("Need at least two instances.")

    def __init__(self):
        super().__init__()
        self.data = None
        # input data restricted to continuous attributes and imputed;
        # None when the input is missing or too small
        self.cont_data = None

        # GUI
        box = gui.vBox(self.mainArea)
        self.correlation_combo = gui.comboBox(
            box, self, "correlation_type",
            items=CorrelationType.items(),
            orientation=Qt.Horizontal,
            callback=self._correlation_combo_changed)

        self.vizrank, _ = CorrelationRank.add_vizrank(
            None, self, None, self._vizrank_selection_changed)
        self.vizrank.progressBar = self.progressBar

        gui.separator(box)
        box_layout = box.layout()
        for rank_widget in (self.vizrank.filter, self.vizrank.rank_table):
            box_layout.addWidget(rank_widget)

        button_box = gui.hBox(self.mainArea)
        button_box.layout().addWidget(self.vizrank.button)

    def sizeHint(self):
        """Return the preferred size of the widget."""
        return QSize(350, 400)

    def _correlation_combo_changed(self):
        # a different correlation type requires a new ranking
        self.apply()

    def _vizrank_selection_changed(self, *args):
        # store the selected attributes and update the outputs
        self.selection = args
        self.commit()

    def _vizrank_select(self):
        """
        Select the first row of the rank table whose attribute names equal
        the names of the attributes in ``self.selection``, if there is one.
        """
        rank_table = self.vizrank.rank_table
        model = rank_table.model()
        wanted = sorted(attr.name for attr in self.selection)
        for row in range(model.rowCount()):
            first_cell = model.index(row, 0)
            # pylint: disable=protected-access
            row_attrs = model.data(first_cell, CorrelationRank._AttrRole)
            if sorted(attr.name for attr in row_attrs) != wanted:
                continue
            selection = QItemSelection()
            selection.select(first_cell, model.index(row, 1))
            rank_table.selectionModel().select(
                selection, QItemSelectionModel.ClearAndSelect)
            break

    @Inputs.data
    def set_data(self, data):
        """
        Accept new input data, recompute the ranking and restore selection.

        ``cont_data`` is set only if the data has at least two continuous
        attributes and at least two instances; otherwise an information
        message is shown.

        :param data: input table or None
        """
        self.closeContext()
        self.clear_messages()
        self.data = data
        self.cont_data = None
        self.selection = ()
        if data is not None:
            domain = data.domain
            cont_attrs = [attr for attr in domain.attributes
                          if attr.is_continuous]
            if len(cont_attrs) >= 2 and len(data) >= 2:
                cont_dom = Domain(cont_attrs, domain.class_vars, domain.metas)
                cont_table = Table.from_table(cont_dom, data)
                self.cont_data = SklImpute()(cont_table)
            elif len(cont_attrs) < 2:
                self.Information.not_enough_vars()
            else:
                self.Information.not_enough_inst()
        self.apply()
        self.openContext(self.data)
        self._vizrank_select()

    def apply(self):
        """
        Reinitialize the rank and start it if there is data to rank;
        otherwise just update the outputs.
        """
        self.vizrank.initialize()
        if self.cont_data is None:
            self.commit()
            return
        # this triggers self.commit() by changing vizrank selection
        self.vizrank.toggle()
        self.vizrank.rank_table.horizontalHeader().setStretchLastSection(True)

    def commit(self):
        """
        Send the outputs.

        Input data is always passed on. Features and correlations are None
        unless there is data to rank. The correlations table has one row per
        row of the rank model, with the correlation parsed from the text
        shown in the first column and the attribute names as metas.
        """
        if self.data is None or self.cont_data is None:
            self.Outputs.data.send(self.data)
            self.Outputs.features.send(None)
            self.Outputs.correlations.send(None)
            return

        domain = Domain(
            [ContinuousVariable("Correlation")],
            metas=[StringVariable("Feature 1"), StringVariable("Feature 2")])
        model = self.vizrank.rank_model
        correlations, names = [], []
        for row in range(model.rowCount()):
            index = model.index(row, 0)
            correlations.append([float(model.data(index))])
            # pylint: disable=protected-access
            row_attrs = model.data(index, CorrelationRank._AttrRole)
            names.append([attr.name for attr in row_attrs])
        x = np.array(correlations)
        m = np.array(names, dtype=object)
        corr_table = Table(domain, x, metas=m)
        corr_table.name = "Correlations"

        # data has been imputed; send original attributes
        original_attrs = [attr.compute_value.variable
                          for attr in self.selection]

        self.Outputs.data.send(self.data)
        self.Outputs.features.send(AttributeList(original_attrs))
        self.Outputs.correlations.send(corr_table)

    def send_report(self):
        """Report the rank table, titled by the chosen correlation type."""
        title = CorrelationType.items()[self.correlation_type]
        self.report_table(title, self.vizrank.rank_table)
| 284 | + |
| 285 | + |
if __name__ == "__main__":
    # Manual run: show the widget on the iris data set and save its
    # settings once the event loop exits.
    app = QApplication([])
    widget = OWCorrelations()
    widget.set_data(Table("iris"))
    widget.show()
    app.exec_()
    widget.saveSettings()
0 commit comments