diff --git a/Orange/widgets/visualize/owscatterplot.py b/Orange/widgets/visualize/owscatterplot.py index 7af55a84837..bd3ef3ec29d 100644 --- a/Orange/widgets/visualize/owscatterplot.py +++ b/Orange/widgets/visualize/owscatterplot.py @@ -1,3 +1,5 @@ +from itertools import chain + import numpy as np from AnyQt.QtCore import Qt, QTimer @@ -8,7 +10,7 @@ from sklearn.metrics import r2_score import Orange -from Orange.data import Table, Domain, ContinuousVariable, DiscreteVariable +from Orange.data import Table, Domain, DiscreteVariable from Orange.canvas import report from Orange.data.sql.table import SqlTable, AUTO_DL_LIMIT from Orange.preprocess.score import ReliefF, RReliefF @@ -26,15 +28,24 @@ class ScatterPlotVizRank(VizRankDialogAttrPair): captionTitle = "Score Plots" minK = 10 + attr_color = None + + def __init__(self, master): + super().__init__(master) + self.attr_color = self.master.graph.attr_color + + def initialize(self): + self.attr_color = self.master.graph.attr_color + super().initialize() def check_preconditions(self): self.Information.add_message( - "class_required", "Data with a class variable is required.") - self.Information.class_required.clear() + "color_required", "Color variable must be selected") + self.Information.color_required.clear() if not super().check_preconditions(): return False - if not self.master.data.domain.class_var: - self.Information.class_required() + if not self.attr_color: + self.Information.color_required() return False return True @@ -46,43 +57,38 @@ def iterate_states(self, initial_state): yield from super().iterate_states(initial_state) def compute_score(self, state): - graph = self.master.graph - attrs = [self.attrs[x] for x in state] - valid = graph.get_valid_list(attrs) - cols = [] - for var in attrs: - cols.append(graph.jittered_data.get_column_view(var)[0][valid]) - X = np.column_stack(cols) - Y = self.master.data.Y[valid] - if X.shape[0] < self.minK: + attrs = [self.attrs[i] for i in state] + data = self.master.graph.scaled_data + data = data.transform(Domain(attrs, self.attr_color)) + data = data[~np.isnan(data.X).any(axis=1) & ~np.isnan(data.Y).T] + if len(data) < self.minK: return - n_neighbors = min(self.minK, len(X) - 1) - knn = NearestNeighbors(n_neighbors=n_neighbors).fit(X) + n_neighbors = min(self.minK, len(data) - 1) + knn = NearestNeighbors(n_neighbors=n_neighbors).fit(data.X) ind = knn.kneighbors(return_distance=False) - if self.master.data.domain.has_discrete_class: - return -np.sum(Y[ind] == Y.reshape(-1, 1)) / n_neighbors / len(Y) + if data.domain.has_discrete_class: + return -np.sum(data.Y[ind] == data.Y.reshape(-1, 1)) / n_neighbors / len(data.Y) else: - return -r2_score(Y, np.mean(Y[ind], axis=1)) * \ - (len(Y) / len(self.master.data)) + return -r2_score(data.Y, np.mean(data.Y[ind], axis=1)) * \ + (len(data.Y) / len(self.master.data)) def bar_length(self, score): return max(0, -score) def score_heuristic(self): - X = self.master.graph.jittered_data.X - Y = self.master.data.Y - mdomain = self.master.data.domain - dom = Domain([ContinuousVariable(str(i)) for i in range(X.shape[1])], - mdomain.class_vars) - data = Table(dom, X, Y) - relief = ReliefF if isinstance(dom.class_var, DiscreteVariable) \ - else RReliefF + assert self.attr_color is not None + master_domain = self.master.graph.scaled_data.domain + vars = [v for v in chain(master_domain.variables, master_domain.metas) + if v is not self.attr_color] + domain = Domain(attributes=vars, class_vars=self.attr_color) + data = self.master.graph.scaled_data.transform(domain) + relief = ReliefF if isinstance(domain.class_var, DiscreteVariable) else RReliefF weights = relief(n_iterations=100, k_nearest=self.minK)(data) - attrs = sorted(zip(weights, mdomain.attributes), - key=lambda x: (-x[0], x[1].name)) + attrs = sorted(zip(weights, domain.attributes), key=lambda x: (-x[0], x[1].name)) return [a for _, a in attrs] + class OWScatterPlot(OWWidget): """Scatterplot visualization with explorative analysis and intelligent data visualization enhancements.""" @@ -216,6 +222,20 @@ def reset_graph_data(self, *_): self.graph.rescale_data() self.update_graph() + def _vizrank_color_change(self): + self.vizrank.initialize() + is_enabled = self.data is not None and not self.data.is_sparse() and \ + len([v for v in chain(self.data.domain.variables, self.data.domain.metas) + if v.is_primitive]) > 2\ + and len(self.data) > 1 + self.vizrank_button.setEnabled( + is_enabled and self.graph.attr_color is not None and + not np.isnan(self.data.get_column_view(self.graph.attr_color)[0].astype(float)).all()) + if is_enabled and self.graph.attr_color is None: + self.vizrank_button.setToolTip("Color variable has to be selected.") + else: + self.vizrank_button.setToolTip("") + @Inputs.data def set_data(self, data): self.clear_messages() @@ -248,19 +268,8 @@ def set_data(self, data): if not same_domain: self.init_attr_values() - self.vizrank.initialize() - self.vizrank.attrs = self.data.domain.attributes if self.data is not None else [] - self.vizrank_button.setEnabled( - self.data is not None and not self.data.is_sparse() and - self.data.domain.class_var is not None and not np.isnan(self.data.Y).all() and - len(self.data.domain.attributes) > 1 and len(self.data) > 1) - if self.data is not None and self.data.domain.class_var is None \ - and len(self.data.domain.attributes) > 1 and len(self.data) > 1: - self.vizrank_button.setToolTip( - "Data with a class variable is required.") - else: - self.vizrank_button.setToolTip("") self.openContext(self.data) + self._vizrank_color_change() def findvar(name, iterable): """Find a Orange.data.Variable in `iterable` by name""" @@ -372,6 +381,7 @@ def update_attr(self): self.send_features() def update_colors(self): + self._vizrank_color_change() self.cb_class_density.setEnabled(self.graph.can_draw_density()) def update_density(self):