Skip to content

Commit 581073d

Browse files
committed
Violin Plot: Support scaling
1 parent c9934ae commit 581073d

File tree

2 files changed

+262
-81
lines changed

2 files changed

+262
-81
lines changed

Orange/widgets/visualize/owviolinplot.py

Lines changed: 107 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# pylint: disable=too-many-lines
2+
from collections import namedtuple
13
from itertools import chain, count
24
from typing import List, Optional, Tuple, Set, Sequence
35

@@ -8,7 +10,7 @@
810
from AnyQt.QtCore import QItemSelection, QPointF, QRectF, QSize, Qt, Signal
911
from AnyQt.QtGui import QBrush, QColor, QPainter, QPainterPath, QPolygonF
1012
from AnyQt.QtWidgets import QCheckBox, QSizePolicy, QGraphicsRectItem, \
11-
QGraphicsSceneMouseEvent, QApplication, QWidget
13+
QGraphicsSceneMouseEvent, QApplication, QWidget, QComboBox
1214

1315
import pyqtgraph as pg
1416

@@ -31,6 +33,9 @@
3133
from Orange.widgets.visualize.utils.plotutils import AxisItem
3234
from Orange.widgets.widget import OWWidget, Input, Output, Msg
3335

36+
# scaling types
37+
AREA, COUNT, WIDTH = range(3)
38+
3439

3540
class ViolinPlotViewBox(pg.ViewBox):
3641
sigSelectionChanged = Signal(QPointF, QPointF, bool)
@@ -107,7 +112,7 @@ def fit_kernel(data: np.ndarray, kernel: str) -> \
107112
Tuple[Optional[KernelDensity], float]:
108113
assert np.all(np.isfinite(data))
109114

110-
if data.size < 2:
115+
if np.unique(data).size < 2:
111116
return None, 1
112117

113118
# obtain bandwidth
@@ -123,31 +128,51 @@ def fit_kernel(data: np.ndarray, kernel: str) -> \
123128
return kde, bw
124129

125130

131+
def scale_density(scale_type: int, density: np.ndarray, n_data: int,
132+
max_density: float) -> np.ndarray:
133+
if scale_type == AREA:
134+
return density
135+
elif scale_type == COUNT:
136+
return density * n_data / max_density
137+
elif scale_type == WIDTH:
138+
return density / max_density
139+
else:
140+
raise NotImplementedError
141+
142+
126143
class ViolinItem(pg.GraphicsObject):
144+
RugPlot = namedtuple("RugPlot", "support, density")
145+
127146
def __init__(self, data: np.ndarray, color: QColor, kernel: str,
128-
show_rug: bool, orientation: Qt.Orientations):
147+
scale: int, show_rug: bool, orientation: Qt.Orientations):
148+
self.__scale = scale
129149
self.__show_rug_plot = show_rug
130150
self.__orientation = orientation
131151

132152
kde, bw = fit_kernel(data, kernel)
133153
self.__kde: KernelDensity = kde
134154
self.__bandwidth: float = bw
135155

136-
self.__violin_path: QPainterPath = self._create_violin(data)
156+
path, max_density = self._create_violin(data)
157+
self.__violin_path: QPainterPath = path
137158
self.__violin_brush: QBrush = QBrush(color)
138159

139-
self.__rug_plot_data: Tuple = self._create_rug_plot(data)
160+
self.__rug_plot_data: ViolinItem.RugPlot = \
161+
self._create_rug_plot(data, max_density)
140162

141163
super().__init__()
142164

143165
@property
144-
def kde(self) -> KernelDensity:
145-
return self.__kde
166+
def density(self) -> np.ndarray:
167+
# density on unique data
168+
return self.__rug_plot_data.density
146169

147170
@property
148171
def violin_width(self) -> float:
149-
return self.boundingRect().width() if self.__orientation == Qt.Vertical \
172+
width = self.boundingRect().width() \
173+
if self.__orientation == Qt.Vertical \
150174
else self.boundingRect().height()
175+
return width or 1
151176

152177
def set_show_rug_plot(self, show: bool):
153178
self.__show_rug_plot = show
@@ -173,13 +198,15 @@ def paint(self, painter: QPainter, *_):
173198

174199
painter.restore()
175200

176-
def _create_violin(self, data: np.ndarray) -> QPainterPath:
201+
def _create_violin(self, data: np.ndarray) -> Tuple[QPainterPath, float]:
177202
if self.__kde is None:
178-
x, p = np.zeros(1), np.zeros(1)
203+
x, p, max_density = np.zeros(1), np.zeros(1), 0
179204
else:
180205
x = np.linspace(data.min() - self.__bandwidth * 2,
181206
data.max() + self.__bandwidth * 2, 1000)
182207
p = np.exp(self.__kde.score_samples(x.reshape(-1, 1)))
208+
max_density = p.max()
209+
p = scale_density(self.__scale, p, len(data), max_density)
183210

184211
if self.__orientation == Qt.Vertical:
185212
pts = [QPointF(pi, xi) for xi, pi in zip(x, p)]
@@ -192,15 +219,17 @@ def _create_violin(self, data: np.ndarray) -> QPainterPath:
192219
polygon = QPolygonF(pts)
193220
path = QPainterPath()
194221
path.addPolygon(polygon)
195-
return path
222+
return path, max_density
196223

197-
def _create_rug_plot(self, data: np.ndarray) -> Tuple:
198-
unique_data = np.unique(data)
224+
def _create_rug_plot(self, data: np.ndarray, max_density: float) -> Tuple:
199225
if self.__kde is None:
200-
return unique_data, np.zeros(unique_data.size)
226+
return self.RugPlot(data, np.zeros(data.size))
201227

202-
density = np.exp(self.__kde.score_samples(unique_data.reshape(-1, 1)))
203-
return unique_data, density
228+
n_data = len(data)
229+
data = np.unique(data) # to optimize scoring
230+
density = np.exp(self.__kde.score_samples(data.reshape(-1, 1)))
231+
density = scale_density(self.__scale, density, n_data, max_density)
232+
return self.RugPlot(data, density)
204233

205234

206235
class BoxItem(pg.GraphicsObject):
@@ -216,7 +245,7 @@ def __init__(self, data: np.ndarray, rect: QRectF,
216245
def boundingRect(self) -> QRectF:
217246
return self.__bounding_rect
218247

219-
def paint(self, painter: QPainter, _, widget: QWidget):
248+
def paint(self, painter: QPainter, _, widget: Optional[QWidget]):
220249
painter.save()
221250

222251
q0, q25, q75, q100 = self.__box_plot_data
@@ -227,7 +256,7 @@ def paint(self, painter: QPainter, _, widget: QWidget):
227256
quartile1 = QPointF(q0, 0), QPointF(q100, 0)
228257
quartile2 = QPointF(q25, 0), QPointF(q75, 0)
229258

230-
factor = widget.devicePixelRatio()
259+
factor = 1 if widget is None else widget.devicePixelRatio()
231260
painter.setPen(pg.mkPen(QColor(Qt.black), width=2 * factor))
232261
painter.drawLine(*quartile1)
233262
painter.setPen(pg.mkPen(QColor(Qt.black), width=6 * factor))
@@ -251,27 +280,38 @@ class MedianItem(pg.ScatterPlotItem):
251280
def __init__(self, data: np.ndarray, orientation: Qt.Orientations):
252281
self.__value = value = 0 if data.size == 0 else np.median(data)
253282
x, y = (0, value) if orientation == Qt.Vertical else (value, 0)
254-
super().__init__(x=[x], y=[y], size=4,
283+
super().__init__(x=[x], y=[y], size=5,
255284
pen=pg.mkPen(QColor(Qt.white)),
256285
brush=pg.mkBrush(QColor(Qt.white)))
257286

258287
@property
259288
def value(self) -> float:
260289
return self.__value
261290

291+
def setX(self, x: float):
292+
self.setData(x=[x], y=[self.value])
293+
294+
def setY(self, y: float):
295+
self.setData(x=[self.value], y=[y])
296+
262297

263298
class StripItem(pg.ScatterPlotItem):
264-
def __init__(self, data: np.ndarray, kde: KernelDensity,
299+
def __init__(self, data: np.ndarray, density: np.ndarray,
265300
color: QColor, orientation: Qt.Orientations):
266-
if kde is not None:
267-
lim = np.exp(kde.score_samples(data.reshape(-1, 1)))
268-
else:
269-
lim = np.zeros(data.size)
270-
x = np.random.RandomState(0).uniform(-lim, lim)
301+
_, indices = np.unique(data, return_inverse=True)
302+
density = density[indices]
303+
self.__xdata = x = np.random.RandomState(0).uniform(-density, density)
304+
self.__ydata = data
271305
x, y = (x, data) if orientation == Qt.Vertical else (data, x)
272306
color = color.lighter(150)
273307
super().__init__(x=x, y=y, size=5, brush=pg.mkBrush(color))
274308

309+
def setX(self, x: float):
310+
self.setData(x=self.__xdata + x, y=self.__ydata)
311+
312+
def setY(self, y: float):
313+
self.setData(x=self.__ydata, y=self.__xdata + y)
314+
275315

276316
class SelectionRect(pg.GraphicsObject):
277317
def __init__(self, rect: QRectF, orientation: Qt.Orientations):
@@ -315,9 +355,10 @@ def paint(self, painter: QPainter, *_):
315355

316356
class ViolinPlot(pg.PlotWidget):
317357
VIOLIN_PADDING_FACTOR = 1.25
358+
SELECTION_PADDING_FACTOR = 1.20
318359
selection_changed = Signal(list, list)
319360

320-
def __init__(self, parent: OWWidget, kernel: str,
361+
def __init__(self, parent: OWWidget, kernel: str, scale: int,
321362
orientation: Qt.Orientations, show_box_plot: bool,
322363
show_strip_plot: bool, show_rug_plot: bool, sort_items: bool):
323364

@@ -329,6 +370,7 @@ def __init__(self, parent: OWWidget, kernel: str,
329370

330371
# settings
331372
self.__kernel = kernel
373+
self.__scale = scale
332374
self.__orientation = orientation
333375
self.__show_box_plot = show_box_plot
334376
self.__show_strip_plot = show_strip_plot
@@ -396,6 +438,11 @@ def set_kernel(self, kernel: str):
396438
self.__kernel = kernel
397439
self._plot_data()
398440

441+
def set_scale(self, scale: int):
442+
if self.__scale != scale:
443+
self.__scale = scale
444+
self._plot_data()
445+
399446
def set_orientation(self, orientation: Qt.Orientations):
400447
if self.__orientation != orientation:
401448
self.__orientation = orientation
@@ -491,6 +538,9 @@ def _set_axes(self):
491538
self.getAxis("left" if vertical else "bottom").setLabel(value_title)
492539
self.getAxis("bottom" if vertical else "left").setLabel(group_title)
493540

541+
if self.__group_var is None:
542+
self.getAxis("bottom" if vertical else "left").setTicks([])
543+
494544
def _plot_data(self):
495545
# save selection ranges
496546
ranges = self._selection_ranges
@@ -516,7 +566,7 @@ def _plot_data(self):
516566
def _set_violin_item(self, values: np.ndarray, color: QColor):
517567
values = values[~np.isnan(values)]
518568

519-
violin = ViolinItem(values, color, self.__kernel,
569+
violin = ViolinItem(values, color, self.__kernel, self.__scale,
520570
self.__show_rug_plot, self.__orientation)
521571
self.addItem(violin)
522572
self.__violin_items.append(violin)
@@ -531,12 +581,13 @@ def _set_violin_item(self, values: np.ndarray, color: QColor):
531581
self.addItem(median)
532582
self.__median_items.append(median)
533583

534-
strip = StripItem(values, violin.kde, color, self.__orientation)
584+
strip = StripItem(values, violin.density, color, self.__orientation)
535585
strip.setVisible(self.__show_strip_plot)
536586
self.addItem(strip)
537587
self.__strip_items.append(strip)
538588

539-
width = violin.violin_width * self.VIOLIN_PADDING_FACTOR
589+
width = self._max_item_width * self.SELECTION_PADDING_FACTOR / \
590+
self.VIOLIN_PADDING_FACTOR
540591
if self.__orientation == Qt.Vertical:
541592
rect = QRectF(-width / 2, median.value, width, 0)
542593
else:
@@ -588,6 +639,10 @@ def _clear_selection(self):
588639

589640
def _update_selection(self, p1: QPointF, p2: QPointF, finished: bool):
590641
# When finished, emit selection_changed.
642+
if len(self.__selection_rects) == 0:
643+
return
644+
assert self._max_item_width > 0
645+
591646
rect = QRectF(p1, p2).normalized()
592647
if self.__orientation == Qt.Vertical:
593648
min_max = rect.y(), rect.y() + rect.height()
@@ -658,7 +713,8 @@ class Error(OWWidget.Error):
658713
not_enough_instances = Msg("Plotting requires at least two instances.")
659714

660715
KERNELS = ["gaussian", "epanechnikov", "linear"]
661-
KERNEL_LABELS = ["Normal kernel", "Epanechnikov kernel", "Linear kernel"]
716+
KERNEL_LABELS = ["Normal", "Epanechnikov", "Linear"]
717+
SCALE_LABELS = ["Area", "Count", "Width"]
662718

663719
settingsHandler = DomainContextHandler()
664720
value_var = ContextSetting(None)
@@ -671,6 +727,7 @@ class Error(OWWidget.Error):
671727
order_violins = Setting(False)
672728
orientation_index = Setting(1) # Vertical
673729
kernel_index = Setting(0) # Normal kernel
730+
scale_index = Setting(0) # Area
674731
selection_ranges = Setting([], schema_only=True)
675732
visual_settings = Setting({}, schema_only=True)
676733

@@ -686,6 +743,7 @@ def __init__(self):
686743
self._value_var_view: ListViewSearch = None
687744
self._group_var_view: ListViewSearch = None
688745
self._order_violins_cb: QCheckBox = None
746+
self._scale_combo: QComboBox = None
689747
self.selection = []
690748
self.__pending_selection: List = self.selection_ranges
691749

@@ -700,7 +758,8 @@ def setup_gui(self):
700758

701759
def _add_graph(self):
702760
box = gui.vBox(self.mainArea)
703-
self.graph = ViolinPlot(self, self.kernel, self.orientation,
761+
self.graph = ViolinPlot(self, self.kernel,
762+
self.scale_index, self.orientation,
704763
self.show_box_plot, self.show_strip_plot,
705764
self.show_rug_plot, self.order_violins)
706765
self.graph.selection_changed.connect(self.__selection_changed)
@@ -754,8 +813,7 @@ def _add_controls(self):
754813
callback=self.apply_group_var_sorting)
755814

756815
box = gui.vBox(self.controlArea, "Display",
757-
sizePolicy=(QSizePolicy.Minimum, QSizePolicy.Maximum),
758-
addSpace=False)
816+
sizePolicy=(QSizePolicy.Minimum, QSizePolicy.Maximum))
759817
gui.checkBox(box, self, "show_box_plot", "Box plot",
760818
callback=self.__show_box_plot_changed)
761819
gui.checkBox(box, self, "show_strip_plot", "Strip plot",
@@ -772,13 +830,15 @@ def _add_controls(self):
772830
callback=self.__orientation_changed)
773831

774832
box = gui.vBox(self.controlArea, "Density Estimation",
775-
sizePolicy=(QSizePolicy.Minimum, QSizePolicy.Maximum),
776-
addSpace=False)
833+
sizePolicy=(QSizePolicy.Minimum, QSizePolicy.Maximum))
777834
gui.comboBox(box, self, "kernel_index", items=self.KERNEL_LABELS,
835+
label="Kernel:", labelWidth=60, orientation=Qt.Horizontal,
778836
callback=self.__kernel_changed)
779-
780-
# stretch over buttonsArea
781-
self.left_side.layout().setStretch(0, 9999999)
837+
self._scale_combo = gui.comboBox(
838+
box, self, "scale_index", items=self.SCALE_LABELS,
839+
label="Scale:", labelWidth=60, orientation=Qt.Horizontal,
840+
callback=self.__scale_changed
841+
)
782842

783843
self._set_input_summary(None)
784844
self._set_output_summary(None)
@@ -796,7 +856,7 @@ def __group_var_changed(self, selection: QItemSelection):
796856
return
797857
self.group_var = selection.indexes()[0].data(gui.TableVariable)
798858
self.apply_value_var_sorting()
799-
self.enable_order_violins_cb()
859+
self.enable_controls()
800860
self.setup_plot()
801861
self.__selection_changed([], [])
802862

@@ -818,6 +878,9 @@ def __orientation_changed(self):
818878
def __kernel_changed(self):
819879
self.graph.set_kernel(self.kernel)
820880

881+
def __scale_changed(self):
882+
self.graph.set_scale(self.scale_index)
883+
821884
@property
822885
def kernel(self) -> str:
823886
# pylint: disable=invalid-sequence-index
@@ -841,7 +904,7 @@ def set_data(self, data: Optional[Table]):
841904
self.set_list_view_selection()
842905
self.apply_value_var_sorting()
843906
self.apply_group_var_sorting()
844-
self.enable_order_violins_cb()
907+
self.enable_controls()
845908
self.setup_plot()
846909
self.apply_selection()
847910

@@ -958,8 +1021,10 @@ def _ensure_selection_visible(view):
9581021
if len(selection) == 1:
9591022
view.scrollTo(selection[0])
9601023

961-
def enable_order_violins_cb(self):
962-
self._order_violins_cb.setEnabled(self.group_var is not None)
1024+
def enable_controls(self):
1025+
enable = self.group_var is not None or not self.data
1026+
self._order_violins_cb.setEnabled(enable)
1027+
self._scale_combo.setEnabled(enable)
9631028

9641029
def setup_plot(self):
9651030
self.graph.clear_plot()

0 commit comments

Comments
 (0)