Skip to content

Commit ba412dd

Browse files
committed
Scatter Plot: Error Bars
1 parent c2c1648 commit ba412dd

File tree

5 files changed

+541
-17
lines changed

5 files changed

+541
-17
lines changed

Orange/widgets/visualize/owscatterplot.py

Lines changed: 157 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import List, Callable
2+
from typing import List, Callable, Optional
33
from xml.sax.saxutils import escape
44

55
import numpy as np
@@ -10,7 +10,7 @@
1010

1111
from AnyQt.QtCore import Qt, QTimer, QPointF
1212
from AnyQt.QtGui import QColor, QFont
13-
from AnyQt.QtWidgets import QGroupBox
13+
from AnyQt.QtWidgets import QGroupBox, QSizePolicy, QPushButton
1414

1515
import pyqtgraph as pg
1616

@@ -29,6 +29,7 @@
2929
from Orange.widgets.utils.widgetpreview import WidgetPreview
3030
from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotBase, \
3131
ScatterBaseParameterSetter
32+
from Orange.widgets.visualize.utils.error_bars_dialog import ErrorBarsDialog
3233
from Orange.widgets.visualize.utils.vizrank import VizRankDialogAttrPair, \
3334
VizRankMixin
3435
from Orange.widgets.visualize.utils.customizableplot import Updater
@@ -150,15 +151,20 @@ def __init__(self, scatter_widget, parent):
150151
self.parameter_setter = ParameterSetter(self)
151152
self.reg_line_items = []
152153
self.ellipse_items: List[pg.PlotCurveItem] = []
154+
self.error_bars_items: List[pg.ErrorBarItem] = []
155+
self.view_box.sigResized.connect(self.update_error_bars)
156+
self.view_box.sigRangeChanged.connect(self.update_error_bars)
153157

154158
def clear(self):
155159
super().clear()
156160
self.reg_line_items.clear()
157161
self.ellipse_items.clear()
162+
self.error_bars_items.clear()
158163

159164
def update_coordinates(self):
160165
super().update_coordinates()
161166
self.update_axes()
167+
self.update_error_bars()
162168
# Don't update_regression line here: update_coordinates is always
163169
# followed by update_point_props, which calls update_colors
164170

@@ -168,6 +174,9 @@ def update_colors(self):
168174
self.update_ellipse()
169175

170176
def jitter_coordinates(self, x, y):
177+
if self.jitter_size == 0:
178+
return x, y
179+
171180
def get_span(attr):
172181
if attr.is_discrete:
173182
# Assuming the maximal jitter size is 10, a span of 4 will
@@ -179,7 +188,7 @@ def get_span(attr):
179188
return 0 # No jittering
180189
span_x = get_span(self.master.attr_x)
181190
span_y = get_span(self.master.attr_y)
182-
if self.jitter_size == 0 or (span_x == 0 and span_y == 0):
191+
if span_x == 0 and span_y == 0:
183192
return x, y
184193
return self._jitter_data(x, y, span_x, span_y)
185194

@@ -333,6 +342,42 @@ def _add_ellipse(self, x: np.ndarray, y: np.ndarray, color: QColor) -> np.ndarra
333342
self.plot_widget.addItem(ellipse)
334343
self.ellipse_items.append(ellipse)
335344

345+
def update_jittering(self):
346+
super().update_jittering()
347+
self.update_error_bars()
348+
349+
def update_error_bars(self):
350+
for item in self.error_bars_items:
351+
self.plot_widget.removeItem(item)
352+
self.error_bars_items.clear()
353+
if not self.master.can_draw_regression_line():
354+
return
355+
356+
x, y = self.get_coordinates()
357+
if x is None:
358+
return
359+
360+
top, bottom, left, right = self.master.get_errors_data()
361+
if top is None and bottom is None and left is None and right is None:
362+
return
363+
364+
px, py = self.view_box.viewPixelSize()
365+
pen = pg.mkPen(color=QColor("#505050"))
366+
367+
# x axis
368+
error_bars = pg.ErrorBarItem(x=x, y=y, left=left, right=right,
369+
beam=py * 10, pen=pen)
370+
error_bars.setZValue(-1)
371+
self.plot_widget.addItem(error_bars)
372+
self.error_bars_items.append(error_bars)
373+
374+
# y axis
375+
error_bars = pg.ErrorBarItem(x=x, y=y, top=top, bottom=bottom,
376+
beam=px * 10, pen=pen)
377+
error_bars.setZValue(-1)
378+
self.plot_widget.addItem(error_bars)
379+
self.error_bars_items.append(error_bars)
380+
336381

337382
class OWScatterPlot(OWDataProjectionWidget, VizRankMixin(ScatterPlotVizRank)):
338383
"""Scatterplot visualization with explorative analysis and intelligent
@@ -355,6 +400,12 @@ class Outputs(OWDataProjectionWidget.Outputs):
355400
auto_sample = Setting(True)
356401
attr_x = ContextSetting(None)
357402
attr_y = ContextSetting(None)
403+
attr_x_upper = ContextSetting(None)
404+
attr_x_lower = ContextSetting(None)
405+
attr_x_is_abs = Setting(False)
406+
attr_y_upper = ContextSetting(None)
407+
attr_y_lower = ContextSetting(None)
408+
attr_y_is_abs = Setting(False)
358409
tooltip_shows_all = Setting(True)
359410

360411
GRAPH_CLASS = OWScatterPlotGraph
@@ -376,6 +427,10 @@ def __init__(self):
376427
self.xy_model: DomainModel = None
377428
self.cb_attr_x: ComboBoxSearch = None
378429
self.cb_attr_y: ComboBoxSearch = None
430+
self.button_attr_x: QPushButton = None
431+
self.button_attr_y: QPushButton = None
432+
self.__x_axis_dlg: ErrorBarsDialog = None
433+
self.__y_axis_dlg: ErrorBarsDialog = None
379434
self.sampling: QGroupBox = None
380435
self._xy_invalidated: bool = True
381436

@@ -425,37 +480,82 @@ def _add_controls_axis(self):
425480
spacing=2 if gui.is_macstyle() else 8)
426481
dmod = DomainModel
427482
self.xy_model = DomainModel(dmod.MIXED, valid_types=dmod.PRIMITIVE)
483+
484+
hbox = gui.hBox(self.attr_box, spacing=2)
428485
self.cb_attr_x = gui.comboBox(
429-
self.attr_box, self, "attr_x", label="Axis x:",
486+
hbox, self, "attr_x", label="Axis x:",
430487
callback=self.set_attr_from_combo,
431488
model=self.xy_model, **common_options,
432489
)
490+
self.button_attr_x = gui.button(
491+
hbox, self, "ϵ", callback=self.__on_x_button_clicked,
492+
autoDefault=False, width=20, enabled=False,
493+
sizePolicy=(QSizePolicy.Maximum, QSizePolicy.Maximum)
494+
)
495+
496+
hbox = gui.hBox(self.attr_box, spacing=2)
433497
self.cb_attr_y = gui.comboBox(
434-
self.attr_box, self, "attr_y", label="Axis y:",
498+
hbox, self, "attr_y", label="Axis y:",
435499
callback=self.set_attr_from_combo,
436500
model=self.xy_model, **common_options,
437501
)
502+
self.button_attr_y = gui.button(
503+
hbox, self, "ϵ", callback=self.__on_y_button_clicked,
504+
autoDefault=False, width=20, enabled=False,
505+
sizePolicy=(QSizePolicy.Maximum, QSizePolicy.Maximum)
506+
)
507+
438508
vizrank_box = gui.hBox(self.attr_box)
439509
button = self.vizrank_button("Find Informative Projections")
440510
vizrank_box.layout().addWidget(button)
441511
self.vizrankSelectionChanged.connect(self.set_attr)
442512

513+
self.__x_axis_dlg = ErrorBarsDialog(self, "Axis x Error Bars")
514+
self.__x_axis_dlg.changed.connect(self.__on_x_dlg_changed)
515+
self.__y_axis_dlg = ErrorBarsDialog(self, "Axis y Error Bars")
516+
self.__y_axis_dlg.changed.connect(self.__on_y_dlg_changed)
517+
518+
def __on_x_button_clicked(self):
519+
self.__x_axis_dlg.show_dlg(self.data.domain, self.attr_x_upper,
520+
self.attr_x_lower, self.attr_x_is_abs)
521+
522+
def __on_y_button_clicked(self):
523+
self.__y_axis_dlg.show_dlg(self.data.domain, self.attr_y_upper,
524+
self.attr_y_lower, self.attr_y_is_abs)
525+
526+
def __on_x_dlg_changed(self):
527+
self.attr_x_upper, self.attr_x_lower, self.attr_x_is_abs = \
528+
self.__x_axis_dlg.get_data()
529+
self.graph.update_error_bars()
530+
531+
def __on_y_dlg_changed(self):
532+
self.attr_y_upper, self.attr_y_lower, self.attr_y_is_abs = \
533+
self.__y_axis_dlg.get_data()
534+
self.graph.update_error_bars()
535+
443536
def _add_controls_sampling(self):
444537
self.sampling = gui.auto_commit(
445538
self.controlArea, self, "auto_sample", "Sample", box="Sampling",
446539
callback=self.switch_sampling, commit=lambda: self.add_data(1))
447540
self.sampling.setVisible(False)
448541

449542
@property
450-
def effective_variables(self):
451-
return [self.attr_x, self.attr_y] if self.attr_x and self.attr_y else []
543+
def effective_variables(self) -> list[Variable]:
544+
variables = []
545+
if self.attr_x and self.attr_y:
546+
variables.append(self.attr_x)
547+
if self.attr_x.name != self.attr_y.name:
548+
variables.append(self.attr_y)
549+
for var in (self.attr_x_upper, self.attr_x_lower,
550+
self.attr_y_upper, self.attr_y_lower):
551+
# set is not used to preserve order
552+
if var and var not in variables:
553+
variables.append(var)
554+
return variables
452555

453556
@property
454557
def effective_data(self):
455-
eff_var = self.effective_variables
456-
if eff_var and self.attr_x.name == self.attr_y.name:
457-
eff_var = [self.attr_x]
458-
return self.data.transform(Domain(eff_var))
558+
return self.data.transform(Domain(self.effective_variables))
459559

460560
def init_vizrank(self):
461561
err_msg = ""
@@ -523,6 +623,14 @@ def check_data(self):
523623
len(self.data.domain.variables) == 0):
524624
self.data = None
525625

626+
def enable_controls(self):
627+
super().enable_controls()
628+
enabled = bool(self.data) and \
629+
self.data.domain.has_continuous_attributes(include_class=True,
630+
include_metas=True)
631+
self.button_attr_x.setEnabled(enabled)
632+
self.button_attr_y.setEnabled(enabled)
633+
526634
def get_embedding(self):
527635
self.valid_data = None
528636
if self.data is None:
@@ -541,6 +649,31 @@ def get_embedding(self):
541649
msg.missing_coords(self.attr_x.name, self.attr_y.name)
542650
return np.vstack((x_data, y_data)).T
543651

652+
def get_errors_data(self) -> tuple[
653+
Optional[np.ndarray], Optional[np.ndarray],
654+
Optional[np.ndarray], Optional[np.ndarray]
655+
]:
656+
x_data = self.get_column(self.attr_x)
657+
y_data = self.get_column(self.attr_y)
658+
top, bottom, left, right = [None] * 4
659+
if self.attr_x_upper:
660+
right = self.get_column(self.attr_x_upper)
661+
if self.attr_x_is_abs:
662+
right = right - x_data
663+
if self.attr_x_lower:
664+
left = self.get_column(self.attr_x_lower)
665+
if self.attr_x_is_abs:
666+
left = x_data - left
667+
if self.attr_y_upper:
668+
top = self.get_column(self.attr_y_upper)
669+
if self.attr_y_is_abs:
670+
top = top - y_data
671+
if self.attr_y_lower:
672+
bottom = self.get_column(self.attr_y_lower)
673+
if self.attr_y_is_abs:
674+
bottom = y_data - bottom
675+
return top, bottom, left, right
676+
544677
# Tooltip
545678
def _point_tooltip(self, point_id, skip_attrs=()):
546679
point_data = self.data[point_id]
@@ -580,6 +713,8 @@ def init_attr_values(self):
580713
self.attr_x = self.xy_model[0] if self.xy_model else None
581714
self.attr_y = self.xy_model[1] if len(self.xy_model) >= 2 \
582715
else self.attr_x
716+
self.attr_x_upper, self.attr_x_lower = None, None
717+
self.attr_y_upper, self.attr_y_lower = None, None
583718

584719
def switch_sampling(self):
585720
self.__timer.stop()
@@ -588,15 +723,15 @@ def switch_sampling(self):
588723
self.__timer.start()
589724

590725
@OWDataProjectionWidget.Inputs.data_subset
591-
def set_subset_data(self, subset_data):
726+
def set_subset_data(self, subset: Optional[Table]):
592727
self.warning()
593-
if isinstance(subset_data, SqlTable):
594-
if subset_data.approx_len() < AUTO_DL_LIMIT:
595-
subset_data = Table(subset_data)
728+
if isinstance(subset, SqlTable):
729+
if subset.approx_len() < AUTO_DL_LIMIT:
730+
subset = Table(subset)
596731
else:
597732
self.warning("Data subset does not support large Sql tables")
598-
subset_data = None
599-
super().set_subset_data(subset_data)
733+
subset = None
734+
super().set_subset_data(subset)
600735

601736
# called when all signals are received, so the graph is updated only once
602737
def handleNewSignals(self):
@@ -608,12 +743,17 @@ def handleNewSignals(self):
608743
self.attr_x, self.attr_y = self.attribute_selection_list
609744
else:
610745
self.attr_x, self.attr_y = None, None
746+
self.attr_x_upper, self.attr_x_lower = None, None
747+
self.attr_y_upper, self.attr_y_lower = None, None
611748
self._invalidated = self._invalidated or self._xy_invalidated
612749
self._xy_invalidated = False
613750
super().handleNewSignals()
614751
if self._domain_invalidated:
615752
self.graph.update_axes()
753+
self.graph.update_error_bars()
616754
self._domain_invalidated = False
755+
if self.attribute_selection_list:
756+
self.graph.update_error_bars()
617757
can_plot = self.can_draw_regression_line()
618758
self.cb_reg_line.setEnabled(can_plot)
619759
self.graph.controls.show_ellipse.setEnabled(can_plot)

0 commit comments

Comments
 (0)