Skip to content

Commit 171bb87

Browse files
committed
Violin Plot: Support scaling
1 parent c9934ae commit 171bb87

File tree

2 files changed

+94
-71
lines changed

2 files changed

+94
-71
lines changed

Orange/widgets/visualize/owviolinplot.py

Lines changed: 65 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# pylint: disable=too-many-lines
2+
from collections import namedtuple
13
from itertools import chain, count
24
from typing import List, Optional, Tuple, Set, Sequence
35

@@ -8,7 +10,7 @@
810
from AnyQt.QtCore import QItemSelection, QPointF, QRectF, QSize, Qt, Signal
911
from AnyQt.QtGui import QBrush, QColor, QPainter, QPainterPath, QPolygonF
1012
from AnyQt.QtWidgets import QCheckBox, QSizePolicy, QGraphicsRectItem, \
11-
QGraphicsSceneMouseEvent, QApplication, QWidget
13+
QGraphicsSceneMouseEvent, QApplication, QWidget, QComboBox
1214

1315
import pyqtgraph as pg
1416

@@ -124,8 +126,12 @@ def fit_kernel(data: np.ndarray, kernel: str) -> \
124126

125127

126128
class ViolinItem(pg.GraphicsObject):
129+
RugPlot = namedtuple("RugPlot", "support, density")
130+
AREA, COUNT, WIDTH = range(3)
131+
127132
def __init__(self, data: np.ndarray, color: QColor, kernel: str,
128-
show_rug: bool, orientation: Qt.Orientations):
133+
scale: int, show_rug: bool, orientation: Qt.Orientations):
134+
self.__scale = scale
129135
self.__show_rug_plot = show_rug
130136
self.__orientation = orientation
131137

@@ -136,13 +142,13 @@ def __init__(self, data: np.ndarray, color: QColor, kernel: str,
136142
self.__violin_path: QPainterPath = self._create_violin(data)
137143
self.__violin_brush: QBrush = QBrush(color)
138144

139-
self.__rug_plot_data: Tuple = self._create_rug_plot(data)
145+
self.__rug_plot_data: ViolinItem.RugPlot = self._create_rug_plot(data)
140146

141147
super().__init__()
142148

143149
@property
144-
def kde(self) -> KernelDensity:
145-
return self.__kde
150+
def density(self) -> np.ndarray:
151+
return self.__rug_plot_data.density
146152

147153
@property
148154
def violin_width(self) -> float:
@@ -180,6 +186,7 @@ def _create_violin(self, data: np.ndarray) -> QPainterPath:
180186
x = np.linspace(data.min() - self.__bandwidth * 2,
181187
data.max() + self.__bandwidth * 2, 1000)
182188
p = np.exp(self.__kde.score_samples(x.reshape(-1, 1)))
189+
p = self.__apply_scaling(p, len(data))
183190

184191
if self.__orientation == Qt.Vertical:
185192
pts = [QPointF(pi, xi) for xi, pi in zip(x, p)]
@@ -194,13 +201,23 @@ def _create_violin(self, data: np.ndarray) -> QPainterPath:
194201
path.addPolygon(polygon)
195202
return path
196203

204+
def __apply_scaling(self, density: np.ndarray, n_data: int) -> np.ndarray:
205+
if self.__scale == self.AREA:
206+
return density
207+
elif self.__scale == self.COUNT:
208+
return density * n_data / density.max()
209+
elif self.__scale == self.WIDTH:
210+
return density / density.max()
211+
else:
212+
raise NotImplementedError
213+
197214
def _create_rug_plot(self, data: np.ndarray) -> Tuple:
198-
unique_data = np.unique(data)
199215
if self.__kde is None:
200-
return unique_data, np.zeros(unique_data.size)
216+
return self.RugPlot(data, np.zeros(data.size))
201217

202-
density = np.exp(self.__kde.score_samples(unique_data.reshape(-1, 1)))
203-
return unique_data, density
218+
density = np.exp(self.__kde.score_samples(data.reshape(-1, 1)))
219+
density = self.__apply_scaling(density, len(data))
220+
return self.RugPlot(data, density)
204221

205222

206223
class BoxItem(pg.GraphicsObject):
@@ -216,7 +233,7 @@ def __init__(self, data: np.ndarray, rect: QRectF,
216233
def boundingRect(self) -> QRectF:
217234
return self.__bounding_rect
218235

219-
def paint(self, painter: QPainter, _, widget: QWidget):
236+
def paint(self, painter: QPainter, _, widget: Optional[QWidget]):
220237
painter.save()
221238

222239
q0, q25, q75, q100 = self.__box_plot_data
@@ -227,7 +244,7 @@ def paint(self, painter: QPainter, _, widget: QWidget):
227244
quartile1 = QPointF(q0, 0), QPointF(q100, 0)
228245
quartile2 = QPointF(q25, 0), QPointF(q75, 0)
229246

230-
factor = widget.devicePixelRatio()
247+
factor = 1 if widget is None else widget.devicePixelRatio()
231248
painter.setPen(pg.mkPen(QColor(Qt.black), width=2 * factor))
232249
painter.drawLine(*quartile1)
233250
painter.setPen(pg.mkPen(QColor(Qt.black), width=6 * factor))
@@ -261,13 +278,9 @@ def value(self) -> float:
261278

262279

263280
class StripItem(pg.ScatterPlotItem):
264-
def __init__(self, data: np.ndarray, kde: KernelDensity,
281+
def __init__(self, data: np.ndarray, density: np.ndarray,
265282
color: QColor, orientation: Qt.Orientations):
266-
if kde is not None:
267-
lim = np.exp(kde.score_samples(data.reshape(-1, 1)))
268-
else:
269-
lim = np.zeros(data.size)
270-
x = np.random.RandomState(0).uniform(-lim, lim)
283+
x = np.random.RandomState(0).uniform(-density, density)
271284
x, y = (x, data) if orientation == Qt.Vertical else (data, x)
272285
color = color.lighter(150)
273286
super().__init__(x=x, y=y, size=5, brush=pg.mkBrush(color))
@@ -317,7 +330,7 @@ class ViolinPlot(pg.PlotWidget):
317330
VIOLIN_PADDING_FACTOR = 1.25
318331
selection_changed = Signal(list, list)
319332

320-
def __init__(self, parent: OWWidget, kernel: str,
333+
def __init__(self, parent: OWWidget, kernel: str, scale: int,
321334
orientation: Qt.Orientations, show_box_plot: bool,
322335
show_strip_plot: bool, show_rug_plot: bool, sort_items: bool):
323336

@@ -329,6 +342,7 @@ def __init__(self, parent: OWWidget, kernel: str,
329342

330343
# settings
331344
self.__kernel = kernel
345+
self.__scale = scale
332346
self.__orientation = orientation
333347
self.__show_box_plot = show_box_plot
334348
self.__show_strip_plot = show_strip_plot
@@ -396,6 +410,11 @@ def set_kernel(self, kernel: str):
396410
self.__kernel = kernel
397411
self._plot_data()
398412

413+
def set_scale(self, scale: int):
414+
if self.__scale != scale:
415+
self.__scale = scale
416+
self._plot_data()
417+
399418
def set_orientation(self, orientation: Qt.Orientations):
400419
if self.__orientation != orientation:
401420
self.__orientation = orientation
@@ -491,6 +510,9 @@ def _set_axes(self):
491510
self.getAxis("left" if vertical else "bottom").setLabel(value_title)
492511
self.getAxis("bottom" if vertical else "left").setLabel(group_title)
493512

513+
if self.__group_var is None:
514+
self.getAxis("bottom" if vertical else "left").setTicks([])
515+
494516
def _plot_data(self):
495517
# save selection ranges
496518
ranges = self._selection_ranges
@@ -516,7 +538,7 @@ def _plot_data(self):
516538
def _set_violin_item(self, values: np.ndarray, color: QColor):
517539
values = values[~np.isnan(values)]
518540

519-
violin = ViolinItem(values, color, self.__kernel,
541+
violin = ViolinItem(values, color, self.__kernel, self.__scale,
520542
self.__show_rug_plot, self.__orientation)
521543
self.addItem(violin)
522544
self.__violin_items.append(violin)
@@ -531,7 +553,7 @@ def _set_violin_item(self, values: np.ndarray, color: QColor):
531553
self.addItem(median)
532554
self.__median_items.append(median)
533555

534-
strip = StripItem(values, violin.kde, color, self.__orientation)
556+
strip = StripItem(values, violin.density, color, self.__orientation)
535557
strip.setVisible(self.__show_strip_plot)
536558
self.addItem(strip)
537559
self.__strip_items.append(strip)
@@ -658,7 +680,8 @@ class Error(OWWidget.Error):
658680
not_enough_instances = Msg("Plotting requires at least two instances.")
659681

660682
KERNELS = ["gaussian", "epanechnikov", "linear"]
661-
KERNEL_LABELS = ["Normal kernel", "Epanechnikov kernel", "Linear kernel"]
683+
KERNEL_LABELS = ["Normal", "Epanechnikov", "Linear"]
684+
SCALE_LABELS = ["Area", "Count", "Width"]
662685

663686
settingsHandler = DomainContextHandler()
664687
value_var = ContextSetting(None)
@@ -671,6 +694,7 @@ class Error(OWWidget.Error):
671694
order_violins = Setting(False)
672695
orientation_index = Setting(1) # Vertical
673696
kernel_index = Setting(0) # Normal kernel
697+
scale_index = Setting(0) # Area
674698
selection_ranges = Setting([], schema_only=True)
675699
visual_settings = Setting({}, schema_only=True)
676700

@@ -686,6 +710,7 @@ def __init__(self):
686710
self._value_var_view: ListViewSearch = None
687711
self._group_var_view: ListViewSearch = None
688712
self._order_violins_cb: QCheckBox = None
713+
self._scale_combo: QComboBox = None
689714
self.selection = []
690715
self.__pending_selection: List = self.selection_ranges
691716

@@ -700,7 +725,8 @@ def setup_gui(self):
700725

701726
def _add_graph(self):
702727
box = gui.vBox(self.mainArea)
703-
self.graph = ViolinPlot(self, self.kernel, self.orientation,
728+
self.graph = ViolinPlot(self, self.kernel,
729+
self.scale_index, self.orientation,
704730
self.show_box_plot, self.show_strip_plot,
705731
self.show_rug_plot, self.order_violins)
706732
self.graph.selection_changed.connect(self.__selection_changed)
@@ -754,8 +780,7 @@ def _add_controls(self):
754780
callback=self.apply_group_var_sorting)
755781

756782
box = gui.vBox(self.controlArea, "Display",
757-
sizePolicy=(QSizePolicy.Minimum, QSizePolicy.Maximum),
758-
addSpace=False)
783+
sizePolicy=(QSizePolicy.Minimum, QSizePolicy.Maximum))
759784
gui.checkBox(box, self, "show_box_plot", "Box plot",
760785
callback=self.__show_box_plot_changed)
761786
gui.checkBox(box, self, "show_strip_plot", "Strip plot",
@@ -772,13 +797,15 @@ def _add_controls(self):
772797
callback=self.__orientation_changed)
773798

774799
box = gui.vBox(self.controlArea, "Density Estimation",
775-
sizePolicy=(QSizePolicy.Minimum, QSizePolicy.Maximum),
776-
addSpace=False)
800+
sizePolicy=(QSizePolicy.Minimum, QSizePolicy.Maximum))
777801
gui.comboBox(box, self, "kernel_index", items=self.KERNEL_LABELS,
802+
label="Kernel:", labelWidth=60, orientation=Qt.Horizontal,
778803
callback=self.__kernel_changed)
779-
780-
# stretch over buttonsArea
781-
self.left_side.layout().setStretch(0, 9999999)
804+
self._scale_combo = gui.comboBox(
805+
box, self, "scale_index", items=self.SCALE_LABELS,
806+
label="Scale:", labelWidth=60, orientation=Qt.Horizontal,
807+
callback=self.__scale_changed
808+
)
782809

783810
self._set_input_summary(None)
784811
self._set_output_summary(None)
@@ -796,7 +823,7 @@ def __group_var_changed(self, selection: QItemSelection):
796823
return
797824
self.group_var = selection.indexes()[0].data(gui.TableVariable)
798825
self.apply_value_var_sorting()
799-
self.enable_order_violins_cb()
826+
self.enable_controls()
800827
self.setup_plot()
801828
self.__selection_changed([], [])
802829

@@ -818,6 +845,9 @@ def __orientation_changed(self):
818845
def __kernel_changed(self):
819846
self.graph.set_kernel(self.kernel)
820847

848+
def __scale_changed(self):
849+
self.graph.set_scale(self.scale_index)
850+
821851
@property
822852
def kernel(self) -> str:
823853
# pylint: disable=invalid-sequence-index
@@ -841,7 +871,7 @@ def set_data(self, data: Optional[Table]):
841871
self.set_list_view_selection()
842872
self.apply_value_var_sorting()
843873
self.apply_group_var_sorting()
844-
self.enable_order_violins_cb()
874+
self.enable_controls()
845875
self.setup_plot()
846876
self.apply_selection()
847877

@@ -958,8 +988,10 @@ def _ensure_selection_visible(view):
958988
if len(selection) == 1:
959989
view.scrollTo(selection[0])
960990

961-
def enable_order_violins_cb(self):
962-
self._order_violins_cb.setEnabled(self.group_var is not None)
991+
def enable_controls(self):
992+
enable = self.group_var is not None or not self.data
993+
self._order_violins_cb.setEnabled(enable)
994+
self._scale_combo.setEnabled(enable)
963995

964996
def setup_plot(self):
965997
self.graph.clear_plot()

0 commit comments

Comments
 (0)