Skip to content

Commit 514ae98

Browse files
authored
Merge pull request #3871 from markotoplak/faster-scatterplot
[ENH] Faster drawing in scatterplot
2 parents 0844614 + 0dde1ab commit 514ae98

File tree

2 files changed

+92
-19
lines changed

2 files changed

+92
-19
lines changed

Orange/widgets/visualize/owscatterplotgraph.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,20 @@ def __init__(self, scatter_widget, parent=None, _="None", view_box=InteractiveVi
221221

222222

223223
class ScatterPlotItem(pg.ScatterPlotItem):
224+
"""PyQtGraph's ScatterPlotItem calls updateSpots at any change of sizes/colors/symbols,
225+
which then rebuilds the stored pixmaps for each symbol. Because Orange calls
226+
set* function in succession, we postpone updateSpots() to paint()."""
227+
228+
_update_spots_in_paint = False
229+
230+
def updateSpots(self, dataSet=None): # pylint: disable=unused-argument
231+
self._update_spots_in_paint = True
232+
self.update()
233+
224234
def paint(self, painter, option, widget=None):
235+
if self._update_spots_in_paint:
236+
self._update_spots_in_paint = False
237+
super().updateSpots()
225238
painter.setRenderHint(QPainter.SmoothPixmapTransform, True)
226239
super().paint(painter, option, widget)
227240

@@ -714,6 +727,7 @@ def get_sizes(self):
714727
715728
The method is called by `update_sizes`. It gets the sizes
716729
from the widget and performs the necessary scaling and sizing.
730+
The output is rounded to half a pixel for faster drawing.
717731
718732
Returns:
719733
(np.ndarray): sizes
@@ -732,7 +746,11 @@ def get_sizes(self):
732746
size_column /= mx
733747
else:
734748
size_column[:] = 0.5
735-
return self.MinShapeSize + (5 + self.point_width) * size_column
749+
750+
sizes = self.MinShapeSize + (5 + self.point_width) * size_column
751+
# round sizes to half pixel for smaller pyqtgraph's symbol pixmap atlas
752+
sizes = (sizes * 2).round() / 2
753+
return sizes
736754

737755
def update_sizes(self):
738756
"""
@@ -861,7 +879,7 @@ def _get_same_colors(self, subset):
861879
(tuple): a list of pens and list of brushes
862880
"""
863881
color = self.plot_widget.palette().color(OWPalette.Data)
864-
pen = [_make_pen(color, 1.5) for _ in range(self.n_shown)]
882+
pen = [_make_pen(color, 1.5)] * self.n_shown # use a single QPen instance
865883
if subset is not None:
866884
brush = np.where(
867885
subset,
@@ -870,7 +888,7 @@ def _get_same_colors(self, subset):
870888
else:
871889
color = QColor(*self.COLOR_DEFAULT)
872890
color.setAlpha(self.alpha_value)
873-
brush = [QBrush(color) for _ in range(self.n_shown)]
891+
brush = [QBrush(color)] * self.n_shown # use a single QBrush instance
874892
return pen, brush
875893

876894
def _get_continuous_colors(self, c_data, subset):
@@ -894,12 +912,31 @@ def _get_continuous_colors(self, c_data, subset):
894912
[pen, np.full((len(pen), 1), self.alpha_value, dtype=int)])
895913
pen *= 100
896914
pen //= self.DarkerValue
897-
pen = [_make_pen(QColor(*col), 1.5) for col in pen.tolist()]
915+
916+
# Reuse pens and brushes with the same colors because PyQtGraph then builds
917+
# smaller pixmap atlas, which makes the drawing faster
918+
919+
def reuse(cache, fn, *args):
920+
if args not in cache:
921+
cache[args] = fn(args)
922+
return cache[args]
923+
924+
def create_pen(col):
925+
return _make_pen(QColor(*col), 1.5)
926+
927+
def create_brush(col):
928+
return QBrush(QColor(*col))
929+
930+
cached_pens = {}
931+
pen = [reuse(cached_pens, create_pen, *col) for col in pen.tolist()]
898932

899933
if subset is not None:
900934
brush[:, 3] = 0
901935
brush[subset, 3] = 255
902-
brush = np.array([QBrush(QColor(*col)) for col in brush.tolist()])
936+
937+
cached_brushes = {}
938+
brush = np.array([reuse(cached_brushes, create_brush, *col) for col in brush.tolist()])
939+
903940
return pen, brush
904941

905942
def _get_discrete_colors(self, c_data, subset):

Orange/widgets/visualize/tests/test_owscatterplotbase.py

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,9 @@ def test_sampling(self):
202202
scatterplot_item = graph.scatterplot_item
203203
x, y = scatterplot_item.getData()
204204
data = scatterplot_item.data
205-
s0, s1, s2, s3 = data["size"] - graph.MinShapeSize
206-
np.testing.assert_almost_equal(
207-
(s2 - s1) / (s1 - s0),
208-
(x[2] - x[1]) / (x[1] - x[0]))
209-
np.testing.assert_almost_equal(
210-
(s2 - s1) / (s1 - s3),
211-
(x[2] - x[1]) / (x[1] - x[3]))
205+
s = data["size"] - graph.MinShapeSize
206+
precise_s = (x - min(x)) / (max(x) - min(x)) * max(s)
207+
np.testing.assert_almost_equal(s, precise_s, decimal=0)
212208
self.assertEqual(
213209
list(data["symbol"]),
214210
[graph.CurveSymbols[int(xi)] for xi in x])
@@ -358,16 +354,24 @@ def test_size_normalization(self):
358354
graph.reset_graph()
359355
scatterplot_item = graph.scatterplot_item
360356
size = scatterplot_item.data["size"]
361-
diffs = [round(y - x, 2) for x, y in zip(size, size[1:])]
362-
self.assertEqual(len(set(diffs)), 1)
363-
self.assertGreater(diffs[0], 0)
357+
np.testing.assert_equal(size, [6, 7.5, 9.5, 11, 12.5, 14.5, 16, 17.5, 19.5, 21])
364358

365359
d = np.arange(10, 20, dtype=float)
366360
graph.update_sizes()
367361
self.assertIs(scatterplot_item, graph.scatterplot_item)
362+
size2 = scatterplot_item.data["size"]
363+
np.testing.assert_equal(size, size2)
364+
365+
def test_size_rounding_half_pixel(self):
366+
graph = self.graph
367+
368+
self.master.get_size_data = lambda: d
369+
d = np.arange(10, dtype=float)
370+
371+
graph.reset_graph()
372+
scatterplot_item = graph.scatterplot_item
368373
size = scatterplot_item.data["size"]
369-
diffs2 = [round(y - x, 2) for x, y in zip(size, size[1:])]
370-
self.assertEqual(diffs, diffs2)
374+
np.testing.assert_equal(size*2 - (size*2).round(), 0)
371375

372376
def test_size_with_nans(self):
373377
graph = self.graph
@@ -493,12 +497,17 @@ def test_colors_discrete(self):
493497
d = np.arange(10, dtype=float) % 2
494498

495499
graph.reset_graph()
500+
data = graph.scatterplot_item.data
496501
self.assertTrue(
497502
all(pen.color().hue() is palette[i % 2].hue()
498-
for i, pen in enumerate(graph.scatterplot_item.data["pen"])))
503+
for i, pen in enumerate(data["pen"])))
499504
self.assertTrue(
500505
all(pen.color().hue() is palette[i % 2].hue()
501-
for i, pen in enumerate(graph.scatterplot_item.data["brush"])))
506+
for i, pen in enumerate(data["brush"])))
507+
508+
# confirm that QPen/QBrush were reused
509+
self.assertEqual(len(set(map(id, data["pen"]))), 2)
510+
self.assertEqual(len(set(map(id, data["brush"]))), 2)
502511

503512
def test_colors_discrete_nan(self):
504513
self.master.is_continuous_color = lambda: False
@@ -529,6 +538,24 @@ def test_colors_continuous(self):
529538
d[4] = np.nan
530539
graph.update_colors() # Ditto
531540

541+
def test_colors_continuous_reused(self):
542+
self.master.is_continuous_color = lambda: True
543+
graph = self.graph
544+
545+
self.xy = (np.arange(100, dtype=float),
546+
np.arange(100, dtype=float))
547+
548+
d = np.arange(100, dtype=float)
549+
self.master.get_color_data = lambda: d
550+
graph.reset_graph()
551+
552+
data = graph.scatterplot_item.data
553+
554+
self.assertEqual(len(data["pen"]), 100)
555+
self.assertLessEqual(len(set(map(id, data["pen"]))), 10)
556+
self.assertEqual(len(data["brush"]), 100)
557+
self.assertLessEqual(len(set(map(id, data["brush"]))), 10)
558+
532559
def test_colors_continuous_nan(self):
533560
self.master.is_continuous_color = lambda: True
534561
graph = self.graph
@@ -602,12 +629,16 @@ def test_colors_none(self):
602629
data = graph.scatterplot_item.data
603630
self.assertTrue(all(pen.color().hue() == hue for pen in data["pen"]))
604631
self.assertTrue(all(pen.color().hue() == hue for pen in data["brush"]))
632+
self.assertEqual(len(set(map(id, data["pen"]))), 1) # test QPen/QBrush reuse
633+
self.assertEqual(len(set(map(id, data["brush"]))), 1)
605634

606635
self.master.get_subset_mask = lambda: np.arange(10) < 5
607636
graph.update_colors()
608637
data = graph.scatterplot_item.data
609638
self.assertTrue(all(pen.color().hue() == hue for pen in data["pen"]))
610639
self.assertTrue(all(pen.color().hue() == hue for pen in data["brush"]))
640+
self.assertEqual(len(set(map(id, data["pen"]))), 1)
641+
self.assertEqual(len(set(map(id, data["brush"]))), 2) # transparent and colored
611642

612643
def test_colors_update_legend_and_density(self):
613644
graph = self.graph
@@ -1247,6 +1278,11 @@ def test_hiding_too_many_labels(self):
12471278
self.assertFalse(spy[-1][0])
12481279
self.assertFalse(bool(self.graph.labels))
12491280

1281+
def test_no_needless_buildatlas(self):
1282+
graph = self.graph
1283+
graph.reset_graph()
1284+
self.assertIsNone(graph.scatterplot_item.fragmentAtlas.atlas)
1285+
12501286

12511287
if __name__ == "__main__":
12521288
import unittest

0 commit comments

Comments
 (0)