Skip to content

Commit c9934ae

Browse files
committed
Violin Plot: Draw all instances inside violin
1 parent aaf3357 commit c9934ae

File tree

2 files changed: 62 additions (+62) and 83 deletions (-83)

Orange/widgets/visualize/owviolinplot.py

Lines changed: 48 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
1-
from collections import namedtuple
21
from itertools import chain, count
3-
from typing import List, Optional, Tuple, Set
2+
from typing import List, Optional, Tuple, Set, Sequence
43

54
import numpy as np
65
from scipy import stats
@@ -95,11 +94,11 @@ def update_ticks(**settings):
9594
}
9695

9796
@property
98-
def title_item(self):
97+
def title_item(self) -> pg.LabelItem:
9998
return self.master.getPlotItem().titleLabel
10099

101100
@property
102-
def axis_items(self):
101+
def axis_items(self) -> List[pg.AxisItem]:
103102
return [value["item"] for value in
104103
self.master.getPlotItem().axes.values()]
105104

@@ -125,8 +124,6 @@ def fit_kernel(data: np.ndarray, kernel: str) -> \
125124

126125

127126
class ViolinItem(pg.GraphicsObject):
128-
RugPlot = namedtuple("RugPlot", "support, density")
129-
130127
def __init__(self, data: np.ndarray, color: QColor, kernel: str,
131128
show_rug: bool, orientation: Qt.Orientations):
132129
self.__show_rug_plot = show_rug
@@ -139,12 +136,16 @@ def __init__(self, data: np.ndarray, color: QColor, kernel: str,
139136
self.__violin_path: QPainterPath = self._create_violin(data)
140137
self.__violin_brush: QBrush = QBrush(color)
141138

142-
self.__rug_plot_data: ViolinItem.RugPlot = self._create_rug_plot(data)
139+
self.__rug_plot_data: Tuple = self._create_rug_plot(data)
143140

144141
super().__init__()
145142

146143
@property
147-
def violin_width(self):
144+
def kde(self) -> KernelDensity:
145+
return self.__kde
146+
147+
@property
148+
def violin_width(self) -> float:
148149
return self.boundingRect().width() if self.__orientation == Qt.Vertical \
149150
else self.boundingRect().height()
150151

@@ -196,21 +197,19 @@ def _create_violin(self, data: np.ndarray) -> QPainterPath:
196197
def _create_rug_plot(self, data: np.ndarray) -> Tuple:
197198
unique_data = np.unique(data)
198199
if self.__kde is None:
199-
return self.RugPlot(unique_data, np.zeros(unique_data.size))
200+
return unique_data, np.zeros(unique_data.size)
200201

201202
density = np.exp(self.__kde.score_samples(unique_data.reshape(-1, 1)))
202-
return self.RugPlot(unique_data, density)
203+
return unique_data, density
203204

204205

205206
class BoxItem(pg.GraphicsObject):
206-
Stats = namedtuple("Stats", "min q25 q75 max")
207-
208207
def __init__(self, data: np.ndarray, rect: QRectF,
209208
orientation: Qt.Orientations):
210209
self.__bounding_rect = rect
211210
self.__orientation = orientation
212211

213-
self.__box_plot_data: BoxItem.Stats = self._create_box_plot(data)
212+
self.__box_plot_data: Tuple = self._create_box_plot(data)
214213

215214
super().__init__()
216215

@@ -239,13 +238,13 @@ def paint(self, painter: QPainter, _, widget: QWidget):
239238
@staticmethod
240239
def _create_box_plot(data: np.ndarray) -> Tuple:
241240
if data.size == 0:
242-
return BoxItem.Stats(*[0] * 4)
241+
return (0,) * 4
243242

244243
q25, q75 = np.percentile(data, [25, 75])
245244
whisker_lim = 1.5 * stats.iqr(data)
246245
min_ = np.min(data[data >= (q25 - whisker_lim)])
247246
max_ = np.max(data[data <= (q75 + whisker_lim)])
248-
return BoxItem.Stats(min_, q25, q75, max_)
247+
return min_, q25, q75, max_
249248

250249

251250
class MedianItem(pg.ScatterPlotItem):
@@ -257,14 +256,18 @@ def __init__(self, data: np.ndarray, orientation: Qt.Orientations):
257256
brush=pg.mkBrush(QColor(Qt.white)))
258257

259258
@property
260-
def value(self):
259+
def value(self) -> float:
261260
return self.__value
262261

263262

264263
class StripItem(pg.ScatterPlotItem):
265-
def __init__(self, data: np.ndarray, lim: float, color: QColor,
266-
orientation: Qt.Orientations):
267-
x = np.random.RandomState(0).uniform(-lim, lim, data.size)
264+
def __init__(self, data: np.ndarray, kde: KernelDensity,
265+
color: QColor, orientation: Qt.Orientations):
266+
if kde is not None:
267+
lim = np.exp(kde.score_samples(data.reshape(-1, 1)))
268+
else:
269+
lim = np.zeros(data.size)
270+
x = np.random.RandomState(0).uniform(-lim, lim)
268271
x, y = (x, data) if orientation == Qt.Vertical else (data, x)
269272
color = color.lighter(150)
270273
super().__init__(x=x, y=y, size=5, brush=pg.mkBrush(color))
@@ -274,27 +277,28 @@ class SelectionRect(pg.GraphicsObject):
274277
def __init__(self, rect: QRectF, orientation: Qt.Orientations):
275278
self.__rect: QRectF = rect
276279
self.__orientation: Qt.Orientations = orientation
277-
self.__selection_range: Optional[Tuple[float]] = None
280+
self.__selection_range: Optional[Tuple[float, float]] = None
278281
super().__init__()
279282

280283
@property
281-
def selection_range(self):
284+
def selection_range(self) -> Optional[Tuple[float, float]]:
282285
return self.__selection_range
283286

284287
@selection_range.setter
285-
def selection_range(self, selection_range: Optional[Tuple[float]]):
288+
def selection_range(self, selection_range: Optional[Tuple[float, float]]):
286289
self.__selection_range = selection_range
287290
self.update()
288291

289292
@property
290-
def selection_rect(self):
293+
def selection_rect(self) -> QRectF:
291294
rect: QRectF = self.__rect
292-
if self.__orientation == Qt.Vertical:
293-
rect.setTop(self.__selection_range[0])
294-
rect.setBottom(self.__selection_range[1])
295-
else:
296-
rect.setLeft(self.__selection_range[0])
297-
rect.setRight(self.__selection_range[1])
295+
if self.__selection_range is not None:
296+
if self.__orientation == Qt.Vertical:
297+
rect.setTop(self.__selection_range[0])
298+
rect.setBottom(self.__selection_range[1])
299+
else:
300+
rect.setLeft(self.__selection_range[0])
301+
rect.setRight(self.__selection_range[1])
298302
return rect
299303

300304
def boundingRect(self) -> QRectF:
@@ -318,8 +322,8 @@ def __init__(self, parent: OWWidget, kernel: str,
318322
show_strip_plot: bool, show_rug_plot: bool, sort_items: bool):
319323

320324
# data
321-
self.__values: np.ndarray = None
322-
self.__value_var: ContinuousVariable = None
325+
self.__values: Optional[np.ndarray] = None
326+
self.__value_var: Optional[ContinuousVariable] = None
323327
self.__group_values: Optional[np.ndarray] = None
324328
self.__group_var: Optional[DiscreteVariable] = None
325329

@@ -356,22 +360,22 @@ def __init__(self, parent: OWWidget, kernel: str,
356360
self.parameter_setter = ParameterSetter(self)
357361

358362
@property
359-
def _selection_ranges(self) -> List[Optional[Tuple[float]]]:
363+
def _selection_ranges(self) -> List[Optional[Tuple[float, float]]]:
360364
return [rect.selection_range for rect in self.__selection_rects]
361365

362366
@_selection_ranges.setter
363-
def _selection_ranges(self, ranges: List[Optional[Tuple[float]]]):
367+
def _selection_ranges(self, ranges: List[Optional[Tuple[float, float]]]):
364368
for min_max, sel_rect in zip(ranges, self.__selection_rects):
365369
sel_rect.selection_range = min_max
366370

367371
@property
368-
def _sorted_group_indices(self):
372+
def _sorted_group_indices(self) -> Sequence[int]:
369373
medians = [item.value for item in self.__median_items]
370374
return np.argsort(medians) if self.__sort_items \
371375
else range(len(medians))
372376

373377
@property
374-
def _max_item_width(self):
378+
def _max_item_width(self) -> float:
375379
if not self.__violin_items:
376380
return 0
377381
return max(item.violin_width * self.VIOLIN_PADDING_FACTOR
@@ -459,7 +463,10 @@ def order_items(self):
459463
for i, index in enumerate(indices)]]
460464
self.getAxis(side).setTicks(ticks)
461465

462-
def set_selection(self, ranges: List[Optional[Tuple[float]]]):
466+
def set_selection(self, ranges: List[Optional[Tuple[float, float]]]):
467+
if self.__values is None:
468+
return
469+
463470
self._selection_ranges = ranges
464471

465472
self.__selection = set()
@@ -524,9 +531,7 @@ def _set_violin_item(self, values: np.ndarray, color: QColor):
524531
self.addItem(median)
525532
self.__median_items.append(median)
526533

527-
br = violin.boundingRect()
528-
lim = br.width() if self.__orientation == Qt.Vertical else br.height()
529-
strip = StripItem(values, lim / 2, color, self.__orientation)
534+
strip = StripItem(values, violin.kde, color, self.__orientation)
530535
strip.setVisible(self.__show_strip_plot)
531536
self.addItem(strip)
532537
self.__strip_items.append(strip)
@@ -814,12 +819,13 @@ def __kernel_changed(self):
814819
self.graph.set_kernel(self.kernel)
815820

816821
@property
817-
def kernel(self):
822+
def kernel(self) -> str:
818823
# pylint: disable=invalid-sequence-index
819824
return self.KERNELS[self.kernel_index]
820825

821826
@property
822-
def orientation(self):
827+
def orientation(self) -> Qt.Orientations:
828+
# pylint: disable=invalid-sequence-index
823829
return [Qt.Horizontal, Qt.Vertical][self.orientation_index]
824830

825831
@Inputs.data
@@ -1011,6 +1017,7 @@ def send_report(self):
10111017

10121018
def set_visual_settings(self, key: KeyType, value: ValueType):
10131019
self.graph.parameter_setter.set_parameter(key, value)
1020+
# pylint: disable=unsupported-assignment-operation
10141021
self.visual_settings[key] = value
10151022

10161023

Orange/widgets/visualize/tests/test_owviolinplot.py

Lines changed: 14 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,7 @@
1-
# pylint: disable=(protected-access
1+
# pylint: disable=protected-access
22
import unittest
33
from unittest.mock import patch
44

5-
import numpy as np
6-
from scipy import stats
7-
from sklearn.neighbors import KernelDensity
85
import matplotlib.pyplot as plt
96
import seaborn as sns
107

@@ -27,9 +24,7 @@ def setUpClass(cls):
2724

2825
cls.signal_name = "Data"
2926
cls.signal_data = cls.data
30-
cls.titanic = Table("titanic")
3127
cls.housing = Table("housing")
32-
cls.heart = Table("heart_disease")
3328

3429
def setUp(self):
3530
self.widget = self.create_widget(OWViolinPlot)
@@ -111,8 +106,7 @@ def test_datasets(self):
111106
self.send_signal(self.widget.Inputs.data, ds)
112107

113108
def test_selection_no_group(self):
114-
data = Table("housing")
115-
self.send_signal(self.widget.Inputs.data, data)
109+
self.send_signal(self.widget.Inputs.data, self.housing)
116110
self.widget.graph._update_selection(QPointF(0, 30), QPointF(0, 40), 1)
117111
selected = self.get_output(self.widget.Outputs.selected_data)
118112
self.assertEqual(len(selected), 53)
@@ -131,8 +125,7 @@ def test_selection_sort_violins(self):
131125
self.assert_table_equal(selected1, selected2)
132126

133127
def test_selection_orientation(self):
134-
data = Table("housing")
135-
self.send_signal(self.widget.Inputs.data, data)
128+
self.send_signal(self.widget.Inputs.data, self.housing)
136129
self.widget.graph._update_selection(QPointF(0, 30), QPointF(0, 40), 1)
137130
self.widget.controls.orientation_index.buttons[0].click()
138131
selected = self.get_output(self.widget.Outputs.selected_data)
@@ -214,12 +207,13 @@ def assertFontEqual(self, font1, font2):
214207
self.assertEqual(font1.pointSize(), font2.pointSize())
215208
self.assertEqual(font1.italic(), font2.italic())
216209

217-
def __select_value(self, list, value):
218-
m = list.model()
219-
for i in range(m.rowCount()):
220-
idx = m.index(i, 0)
221-
if m.data(idx) == value:
222-
list.selectionModel().setCurrentIndex(
210+
@staticmethod
211+
def __select_value(list_, value):
212+
model = list_.model()
213+
for i in range(model.rowCount()):
214+
idx = model.index(i, 0)
215+
if model.data(idx) == value:
216+
list_.selectionModel().setCurrentIndex(
223217
idx, QItemSelectionModel.ClearAndSelect)
224218

225219
def test_seaborn(self):
@@ -240,41 +234,19 @@ def test_seaborn(self):
240234
# hue = df["chest pain"]
241235
print(y.min(), y.max())
242236
sns.violinplot(
243-
x=y,
244-
y=x,
237+
x=x,
238+
y=y,
245239
# inner="stick",
246-
orient="h",
240+
# orient="h",
247241
# hue=hue,
248-
# scale="count",
242+
scale="count",
249243
# data=df,
250244
# split=True,
251245
)
252246
plt.show()
253247

254248
sns.kdeplot()
255249

256-
def test_sklearn(self):
257-
table = Table("iris")
258-
self.assertEqual(True, False)
259-
260-
data = table.X[:, 0]
261-
kde = stats.gaussian_kde(data)
262-
bw = kde.factor * data.std(ddof=1)
263-
print(bw)
264-
# bw = 1
265-
kde = stats.gaussian_kde(data, bw_method=bw / data.std(ddof=1))
266-
support1 = np.linspace(data.min() - bw * 2, data.max() + bw * 2, 100)
267-
density1 = kde.evaluate(support1)
268-
269-
data = table.X[:, 0]
270-
kde = KernelDensity(bandwidth=bw, kernel="gaussian")
271-
kde.fit(data.reshape(-1, 1))
272-
support2 = np.linspace(data.min() - bw * 2, data.max() + bw * 2, 100)
273-
density2 = np.exp(kde.score_samples(support2.reshape(-1, 1)))
274-
275-
np.testing.assert_array_equal(support1, support2)
276-
np.testing.assert_array_almost_equal(density1, density2)
277-
278250

279251
if __name__ == "__main__":
280252
unittest.main()

0 commit comments

Comments
 (0)