Skip to content

Commit ff8e05f

Browse files
committed
Scatterplot: Add tests for regression lines
1 parent 41a4a56 commit ff8e05f

File tree

1 file changed

+274
-3
lines changed

1 file changed

+274
-3
lines changed

Orange/widgets/visualize/tests/test_owscatterplot.py

Lines changed: 274 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring,too-many-public-methods,protected-access
3+
# pylint: disable=too-many-lines
34
from unittest.mock import MagicMock, patch, Mock
45
import numpy as np
56

67
from AnyQt.QtCore import QRectF, Qt
78
from AnyQt.QtWidgets import QToolTip
9+
from AnyQt.QtGui import QColor
810

911
from Orange.data import Table, Domain, ContinuousVariable, DiscreteVariable
1012
from Orange.widgets.tests.base import (
1113
WidgetTest, WidgetOutputsTestMixin, datasets, ProjectionWidgetTestMixin
1214
)
1315
from Orange.widgets.tests.utils import simulate
16+
from Orange.widgets.utils.colorpalette import DefaultRGBColors
1417
from Orange.widgets.visualize.owscatterplot import (
15-
OWScatterPlot, ScatterPlotVizRank
16-
)
18+
OWScatterPlot, ScatterPlotVizRank, OWScatterPlotGraph)
1719
from Orange.widgets.visualize.utils.widget import MAX_CATEGORIES
1820
from Orange.widgets.widget import AttributeList
1921

@@ -735,12 +737,281 @@ def test_on_manual_change(self):
735737
selection = vizrank.rank_table.selectedIndexes()
736738
self.assertEqual(len(selection), 0)
737739

738-
def test_regression_line(self):
740+
def test_regression_lines_appear(self):
739741
self.widget.graph.controls.show_reg_line.setChecked(True)
742+
self.assertEqual(len(self.widget.graph.reg_line_items), 0)
740743
self.send_signal(self.widget.Inputs.data, self.data)
744+
self.assertEqual(len(self.widget.graph.reg_line_items), 4)
745+
simulate.combobox_activate_index(self.widget.controls.attr_color, 0)
746+
self.assertEqual(len(self.widget.graph.reg_line_items), 1)
741747
data = self.data.copy()
742748
data[:, 0] = np.nan
743749
self.send_signal(self.widget.Inputs.data, data)
750+
self.assertEqual(len(self.widget.graph.reg_line_items), 0)
751+
752+
def test_regression_line_coeffs(self):
753+
widget = self.widget
754+
graph = widget.graph
755+
xy = np.array([[0, 0], [1, 0], [1, 2], [2, 2],
756+
[0, 1], [1, 3], [2, 5]], dtype=np.float)
757+
colors = np.array([0, 0, 0, 0, 1, 1, 1], dtype=np.float)
758+
widget.get_coordinates_data = lambda: xy.T
759+
widget.get_color_data = lambda: colors
760+
widget.is_continuous_color = lambda: False
761+
graph.palette = DefaultRGBColors
762+
graph.controls.show_reg_line.setChecked(True)
763+
764+
graph.update_regression_line()
765+
766+
line1 = graph.reg_line_items[1]
767+
self.assertEqual(line1.pos().x(), 0)
768+
self.assertEqual(line1.pos().y(), 0)
769+
self.assertEqual(line1.angle, 45)
770+
self.assertEqual(line1.pen.color().getRgb()[:3], graph.palette[0])
771+
772+
line2 = graph.reg_line_items[2]
773+
self.assertEqual(line2.pos().x(), 0)
774+
self.assertEqual(line2.pos().y(), 1)
775+
self.assertAlmostEqual(line2.angle, np.degrees(np.arctan2(2, 1)))
776+
self.assertEqual(line2.pen.color().getRgb()[:3], graph.palette[1])
777+
778+
graph.orthonormal_regression = True
779+
graph.update_regression_line()
780+
781+
line1 = graph.reg_line_items[1]
782+
self.assertEqual(line1.pos().x(), 0)
783+
self.assertAlmostEqual(line1.pos().y(), -0.6180339887498949)
784+
self.assertEqual(line1.angle, 58.28252558853899)
785+
self.assertEqual(line1.pen.color().getRgb()[:3], graph.palette[0])
786+
787+
line2 = graph.reg_line_items[2]
788+
self.assertEqual(line2.pos().x(), 0)
789+
self.assertEqual(line2.pos().y(), 1)
790+
self.assertAlmostEqual(line2.angle, np.degrees(np.arctan2(2, 1)))
791+
self.assertEqual(line2.pen.color().getRgb()[:3], graph.palette[1])
792+
793+
def test_orthonormal_line(self):
794+
color = QColor(1, 2, 3)
795+
width = 42
796+
# Normal line
797+
line = OWScatterPlotGraph._orthonormal_line(
798+
np.array([0, 1, 1, 2]), np.array([0, 0, 2, 2]), color, width)
799+
self.assertEqual(line.pos().x(), 0)
800+
self.assertAlmostEqual(line.pos().y(), -0.6180339887498949)
801+
self.assertEqual(line.angle, 58.28252558853899)
802+
self.assertEqual(line.pen.color(), color)
803+
self.assertEqual(line.pen.width(), width)
804+
805+
# Normal line, negative slope
806+
line = OWScatterPlotGraph._orthonormal_line(
807+
np.array([1, 2, 3]), np.array([3, 2, 1]), color, width)
808+
self.assertEqual(line.pos().x(), 1)
809+
self.assertEqual(line.pos().y(), 3)
810+
self.assertEqual(line.angle % 360, 315)
811+
812+
# Horizontal line
813+
line = OWScatterPlotGraph._orthonormal_line(
814+
np.array([10, 11, 12]), np.array([42, 42, 42]), color, width)
815+
self.assertEqual(line.pos().x(), 10)
816+
self.assertEqual(line.pos().y(), 42)
817+
self.assertEqual(line.angle, 0)
818+
819+
# Vertical line
820+
line = OWScatterPlotGraph._orthonormal_line(
821+
np.array([42, 42, 42]), np.array([10, 11, 12]), color, width)
822+
self.assertEqual(line.pos().x(), 42)
823+
self.assertEqual(line.pos().y(), 10)
824+
self.assertEqual(line.angle, 90)
825+
826+
# No line because all points coincide
827+
line = OWScatterPlotGraph._orthonormal_line(
828+
np.array([1, 1, 1]), np.array([42, 42, 42]), color, width)
829+
self.assertIsNone(line)
830+
831+
# No line because the group is symmetric
832+
line = OWScatterPlotGraph._orthonormal_line(
833+
np.array([1, 1, 2, 2]), np.array([42, 5, 5, 42]), color, width)
834+
self.assertIsNone(line)
835+
836+
def test_regression_line(self):
837+
color = QColor(1, 2, 3)
838+
width = 42
839+
# Normal line
840+
line = OWScatterPlotGraph._regression_line(
841+
np.array([0, 1, 1, 2]), np.array([0, 0, 2, 2]), color, width)
842+
self.assertEqual(line.pos().x(), 0)
843+
self.assertAlmostEqual(line.pos().y(), 0)
844+
self.assertEqual(line.angle, 45)
845+
self.assertEqual(line.pen.color(), color)
846+
self.assertEqual(line.pen.width(), width)
847+
848+
# Normal line, negative slope
849+
line = OWScatterPlotGraph._regression_line(
850+
np.array([1, 2, 3]), np.array([3, 2, 1]), color, width)
851+
self.assertEqual(line.pos().x(), 1)
852+
self.assertEqual(line.pos().y(), 3)
853+
self.assertEqual(line.angle % 360, 315)
854+
855+
# Horizontal line
856+
line = OWScatterPlotGraph._regression_line(
857+
np.array([10, 11, 12]), np.array([42, 42, 42]), color, width)
858+
self.assertEqual(line.pos().x(), 10)
859+
self.assertEqual(line.pos().y(), 42)
860+
self.assertEqual(line.angle, 0)
861+
862+
# Vertical line
863+
line = OWScatterPlotGraph._regression_line(
864+
np.array([42, 42, 42]), np.array([10, 11, 12]), color, width)
865+
self.assertIsNone(line)
866+
867+
# No line because all points coincide
868+
line = OWScatterPlotGraph._regression_line(
869+
np.array([1, 1, 1]), np.array([42, 42, 42]), color, width)
870+
self.assertIsNone(line)
871+
872+
def test_add_line_calls_proper_regressor(self):
873+
graph = self.widget.graph
874+
graph._orthonormal_line = Mock(return_value=None)
875+
graph._regression_line = Mock(return_value=None)
876+
x, y, c, w = Mock(), Mock(), Mock(), Mock()
877+
878+
graph.orthonormal_regression = True
879+
graph._add_line(x, y, c, w)
880+
graph._orthonormal_line.assert_called_once_with(x, y, c, w)
881+
graph._regression_line.assert_not_called()
882+
graph._orthonormal_line.reset_mock()
883+
884+
graph.orthonormal_regression = False
885+
graph._add_line(x, y, c, w)
886+
graph._regression_line.assert_called_with(x, y, c, w)
887+
graph._orthonormal_line.assert_not_called()
888+
889+
def test_no_regression_line(self):
890+
graph = self.widget.graph
891+
graph._orthonormal_line = lambda *_: None
892+
graph.orthonormal_regression = True
893+
894+
graph.plot_widget.addItem = Mock()
895+
896+
x, y, c, w = Mock(), Mock(), Mock(), Mock()
897+
graph._add_line(x, y, c, w)
898+
graph.plot_widget.addItem.assert_not_called()
899+
self.assertEqual(graph.reg_line_items, [])
900+
901+
def test_update_regression_line_calls_add_line(self):
902+
widget = self.widget
903+
graph = widget.graph
904+
x, y = np.array([[0, 0], [1, 0], [1, 2], [2, 2],
905+
[0, 1], [1, 3], [2, 5]], dtype=np.float).T
906+
colors = np.array([0, 0, 0, 0, 1, 1, 1], dtype=np.float)
907+
widget.get_coordinates_data = lambda: (x, y)
908+
widget.get_color_data = lambda: colors
909+
widget.is_continuous_color = lambda: False
910+
graph.palette = DefaultRGBColors
911+
graph.controls.show_reg_line.setChecked(True)
912+
913+
graph._add_line = Mock()
914+
915+
graph.update_regression_line()
916+
(args1, kwargs1), (args2, kwargs2), (args3, kwargs3) = \
917+
graph._add_line.call_args_list
918+
np.testing.assert_equal(args1[0], x)
919+
np.testing.assert_equal(args1[1], y)
920+
self.assertEqual(args1[2], QColor("#505050"))
921+
self.assertEqual(kwargs1["width"], 1)
922+
923+
np.testing.assert_equal(args2[0], x[:4])
924+
np.testing.assert_equal(args2[1], y[:4])
925+
self.assertEqual(args2[2], graph.palette[0])
926+
self.assertEqual(kwargs2["width"], 3)
927+
928+
np.testing.assert_equal(args3[0], x[4:])
929+
np.testing.assert_equal(args3[1], y[4:])
930+
self.assertEqual(args3[2], graph.palette[1])
931+
self.assertEqual(kwargs3["width"], 3)
932+
graph._add_line.reset_mock()
933+
934+
# Continuous color - just a single line
935+
widget.is_continuous_color = lambda: True
936+
graph.update_regression_line()
937+
graph._add_line.assert_called_once()
938+
args1, kwargs1 = graph._add_line.call_args_list[0]
939+
np.testing.assert_equal(args1[0], x)
940+
np.testing.assert_equal(args1[1], y)
941+
self.assertEqual(args1[2], QColor("#505050"))
942+
self.assertEqual(kwargs1["width"], 1)
943+
graph._add_line.reset_mock()
944+
widget.is_continuous_color = lambda: False
945+
946+
# No palette - just a single line
947+
graph.palette = None
948+
graph.update_regression_line()
949+
graph._add_line.assert_called_once()
950+
graph._add_line.reset_mock()
951+
graph.palette = DefaultRGBColors
952+
953+
# Regression line is disabled
954+
graph.show_reg_line = False
955+
graph.update_regression_line()
956+
graph._add_line.assert_not_called()
957+
graph.show_reg_line = True
958+
959+
# No colors - just one line
960+
widget.get_color_data = lambda: None
961+
graph.update_regression_line()
962+
graph._add_line.assert_called_once()
963+
graph._add_line.reset_mock()
964+
965+
# No data
966+
widget.get_coordinates_data = lambda: (None, None)
967+
graph.update_regression_line()
968+
graph._add_line.assert_not_called()
969+
graph.show_reg_line = True
970+
widget.get_coordinates_data = lambda: (x, y)
971+
972+
# One color group contains just one point - skip that line
973+
widget.get_color_data = lambda: np.array([0] + [1] * (len(x) - 1))
974+
975+
graph.update_regression_line()
976+
(args1, kwargs1), (args2, kwargs2) = graph._add_line.call_args_list
977+
np.testing.assert_equal(args1[0], x)
978+
np.testing.assert_equal(args1[1], y)
979+
self.assertEqual(args1[2], QColor("#505050"))
980+
self.assertEqual(kwargs1["width"], 1)
981+
982+
np.testing.assert_equal(args2[0], x[1:])
983+
np.testing.assert_equal(args2[1], y[1:])
984+
self.assertEqual(args2[2], graph.palette[1])
985+
self.assertEqual(kwargs2["width"], 3)
986+
987+
def test_update_regression_line_is_called(self):
988+
widget = self.widget
989+
graph = widget.graph
990+
urline = graph.update_regression_line = Mock()
991+
992+
self.send_signal(widget.Inputs.data, self.data)
993+
urline.assert_called_once()
994+
urline.reset_mock()
995+
996+
self.send_signal(widget.Inputs.data, None)
997+
urline.assert_called_once()
998+
urline.reset_mock()
999+
1000+
self.send_signal(widget.Inputs.data, self.data)
1001+
urline.assert_called_once()
1002+
urline.reset_mock()
1003+
1004+
simulate.combobox_activate_index(self.widget.controls.attr_color, 0)
1005+
urline.assert_called_once()
1006+
urline.reset_mock()
1007+
1008+
simulate.combobox_activate_index(self.widget.controls.attr_color, 2)
1009+
urline.assert_called_once()
1010+
urline.reset_mock()
1011+
1012+
simulate.combobox_activate_index(self.widget.controls.attr_x, 3)
1013+
urline.assert_called_once()
1014+
urline.reset_mock()
7441015

7451016

7461017
if __name__ == "__main__":

0 commit comments

Comments
 (0)