Skip to content

Commit 8d882bf

Browse files
authored
Merge pull request #2335 from jerneju/attribute-mds
[FIX] MDS: Support distances without domain information
2 parents 6cd35ad + a23781c commit 8d882bf

File tree

2 files changed

+63
-67
lines changed

2 files changed

+63
-67
lines changed

Orange/widgets/unsupervised/owmds.py

Lines changed: 46 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import sys
2-
import warnings
32

43
from xml.sax.saxutils import escape
54
from itertools import chain
@@ -36,11 +35,11 @@
3635
ANNOTATED_DATA_SIGNAL_NAME)
3736

3837

39-
def stress(X, D):
40-
assert X.shape[0] == D.shape[0] == D.shape[1]
38+
def stress(X, distD):
39+
assert X.shape[0] == distD.shape[0] == distD.shape[1]
4140
D1_c = scipy.spatial.distance.pdist(X, metric="euclidean")
4241
D1 = scipy.spatial.distance.squareform(D1_c, checks=False)
43-
delta = D1 - D
42+
delta = D1 - distD
4443
delta_sq = numpy.square(delta, out=delta)
4544
return delta_sq.sum(axis=0) / 2
4645

@@ -54,9 +53,6 @@ def make_pen(color, width=1.5, style=Qt.SolidLine, cosmetic=False):
5453
class ScatterPlotItem(pg.ScatterPlotItem):
5554
Symbols = pyqtgraph.graphicsItems.ScatterPlotItem.Symbols
5655

57-
def __init__(self, *args, **kwargs):
58-
super().__init__(*args, **kwargs)
59-
6056
def paint(self, painter, option, widget=None):
6157
if self.opts["pxMode"]:
6258
painter.setRenderHint(QPainter.SmoothPixmapTransform, True)
@@ -457,15 +453,15 @@ def _clear_plot(self):
457453
self._legend_item = None
458454

459455
def update_controls(self):
460-
if self.data is None and getattr(self.matrix, 'axis', 1) == 0:
461-
# Column-wise distances
462-
attr = "Attribute names"
456+
if self.data is None:
457+
axis = getattr(self.matrix, 'axis', 1)
458+
attr = ["Column labels", "Row labels"][axis]
463459
self.labelvar_model[:] = ["No labels", attr]
464-
self.shapevar_model[:] = ["Same shape", attr]
465-
self.colorvar_model[:] = ["Same solor", attr]
460+
self.shapevar_model[:] = ["Same shape"]
461+
self.colorvar_model[:] = ["Same color"]
466462

467-
self.color_value = attr
468-
self.shape_value = attr
463+
self.color_value = "Same color"
464+
self.shape_value = "Same shape"
469465
else:
470466
domain = self.data.domain
471467
all_vars = list(filter_visible(domain.variables + domain.metas))
@@ -642,7 +638,7 @@ def __next_step(self):
642638
loop = self.__update_loop
643639
self.Error.out_of_memory.clear()
644640
try:
645-
embedding, stress, progress = next(self.__update_loop)
641+
embedding, _, progress = next(self.__update_loop)
646642
assert self.__update_loop is loop
647643
except StopIteration:
648644
self.__set_update_loop(None)
@@ -750,7 +746,7 @@ def _update_plot(self):
750746
def _setup_plot(self):
751747
have_data = self.data is not None
752748
have_matrix_transposed = self.matrix is not None and not self.matrix.axis
753-
plotstyle = mdsplotutils.plotstyle
749+
plotstyle = Mdsplotutils.plotstyle
754750

755751
size = self._effective_matrix.shape[0]
756752

@@ -761,7 +757,9 @@ def column(data, variable, dtype=None):
761757
return a.ravel()
762758

763759
def attributes(matrix):
764-
return matrix.row_items.domain.attributes
760+
if matrix.row_items and matrix.row_items.domain:
761+
return matrix.row_items.domain.attributes
762+
return [x + 1 for x in range(len(matrix))]
765763

766764
def scale(a):
767765
dmin, dmax = numpy.nanmin(a), numpy.nanmax(a)
@@ -785,7 +783,7 @@ def jitter(x, factor=1, rstate=None):
785783
if self._selection_mask is not None:
786784
pointflags = numpy.where(
787785
self._selection_mask,
788-
mdsplotutils.Selected, mdsplotutils.NoFlags)
786+
Mdsplotutils.Selected, Mdsplotutils.NoFlags)
789787
else:
790788
pointflags = None
791789

@@ -800,35 +798,23 @@ def jitter(x, factor=1, rstate=None):
800798
else:
801799
palette = None
802800

803-
color_data = mdsplotutils.color_data(
801+
color_data = Mdsplotutils.color_data(
804802
self.data, color_var, plotstyle=plotstyle)
805803
color_data = numpy.hstack(
806804
(color_data,
807805
numpy.full((len(color_data), 1), self.symbol_opacity,
808806
dtype=float))
809807
)
810-
pen_data = mdsplotutils.pen_data(color_data * 0.8, pointflags)
811-
brush_data = mdsplotutils.brush_data(color_data)
812-
elif have_matrix_transposed and \
813-
self.colorvar_model[color_index] == 'Attribute names':
814-
attr = attributes(self.matrix)
815-
palette = colorpalette.ColorPaletteGenerator(len(attr))
816-
color_data = [palette.getRGB(i) for i in range(len(attr))]
817-
color_data = numpy.hstack((
818-
color_data,
819-
numpy.full((len(color_data), 1), self.symbol_opacity,
820-
dtype=float))
821-
)
822-
pen_data = mdsplotutils.pen_data(color_data * 0.8, pointflags)
823-
brush_data = mdsplotutils.brush_data(color_data)
808+
pen_data = Mdsplotutils.pen_data(color_data * 0.8, pointflags)
809+
brush_data = Mdsplotutils.brush_data(color_data)
824810
else:
825811
pen_data = make_pen(QColor(Qt.darkGray), cosmetic=True)
826812
if self._selection_mask is not None:
827813
pen_data = numpy.array(
828814
[pen_data, plotstyle.selected_pen])
829815
pen_data = pen_data[self._selection_mask.astype(int)]
830816
else:
831-
pen_data = numpy.full(self._effective_matrix.dim, pen_data,
817+
pen_data = numpy.full(self._effective_matrix.shape[0], pen_data,
832818
dtype=object)
833819
brush_data = numpy.full(
834820
size, pg.mkColor((192, 192, 192, self.symbol_opacity)),
@@ -853,13 +839,6 @@ def jitter(x, factor=1, rstate=None):
853839
data = data % (len(Symbols) - 1)
854840
data[numpy.isnan(data)] = len(Symbols) - 1
855841
shape_data = symbols[data.astype(int)]
856-
elif have_matrix_transposed and \
857-
self.shapevar_model[shape_index] == 'Attribute names':
858-
Symbols = ScatterPlotItem.Symbols
859-
symbols = numpy.array(list(Symbols.keys()))
860-
attr = [i % (len(Symbols) - 1)
861-
for i, _ in enumerate(attributes(self.matrix))]
862-
shape_data = symbols[attr]
863842
else:
864843
shape_data = "o"
865844
self._shape_data = shape_data
@@ -891,10 +870,10 @@ def jitter(x, factor=1, rstate=None):
891870
label_data = [label_var.str_val(val) for val in label_data]
892871
label_items = [pg.TextItem(text, anchor=(0.5, 0), color=0.0)
893872
for text in label_data]
894-
elif have_matrix_transposed and \
895-
self.labelvar_model[label_index] == 'Attribute names':
873+
elif self.matrix is not None and label_index:
896874
attr = attributes(self.matrix)
897-
label_items = [pg.TextItem(str(text), anchor=(0.5, 0))
875+
label_items = [pg.TextItem(str(text), anchor=(0.5, 0),
876+
color=0.0)
898877
for text in attr]
899878
else:
900879
label_items = None
@@ -986,7 +965,7 @@ def jitter(x, factor=1, rstate=None):
986965
if shape_var is not None or \
987966
(color_var is not None and color_var.is_discrete):
988967

989-
legend_data = mdsplotutils.legend_data(
968+
legend_data = Mdsplotutils.legend_data(
990969
color_var, shape_var, plotstyle=plotstyle)
991970

992971
for color, symbol, text in legend_data:
@@ -1148,18 +1127,18 @@ def scaled(a):
11481127
from Orange.widgets.visualize.owlinearprojection import plotutils
11491128

11501129

1151-
class namespace(namespace):
1130+
class Namespace(namespace):
11521131
def updated(self, **kwargs):
11531132
ns = self.__dict__.copy()
11541133
ns.update(**kwargs)
1155-
return namespace(**ns)
1134+
return Namespace(**ns)
11561135

11571136

1158-
class mdsplotutils(plotutils):
1137+
class Mdsplotutils(plotutils):
11591138
NoFlags, Selected, Highlight = 0, 1, 2
11601139
NoFill, Filled = 0, 1
11611140

1162-
plotstyle = namespace(
1141+
plotstyle = Namespace(
11631142
selected_pen=make_pen(Qt.yellow, width=3, cosmetic=True),
11641143
highligh_pen=QPen(Qt.blue, 1),
11651144
selected_brush=None,
@@ -1194,13 +1173,13 @@ def color_data(table, var=None, mask=None, plotstyle=None):
11941173
N = numpy.count_nonzero(mask)
11951174

11961175
if plotstyle is None:
1197-
plotstyle = mdsplotutils.plotstyle
1176+
plotstyle = Mdsplotutils.plotstyle
11981177

11991178
if var is None:
12001179
col = numpy.zeros(N, dtype=float)
12011180
color_data = numpy.full(N, plotstyle.default_color, dtype=object)
12021181
elif var.is_primitive():
1203-
col = mdsplotutils.column_data(table, var, mask)
1182+
col = Mdsplotutils.column_data(table, var, mask)
12041183
if var.is_discrete:
12051184
palette = plotstyle.discrete_palette
12061185
if len(var.values) >= palette.number_of_colors:
@@ -1219,21 +1198,21 @@ def color_data(table, var=None, mask=None, plotstyle=None):
12191198
@staticmethod
12201199
def pen_data(basecolors, flags=None, plotstyle=None):
12211200
if plotstyle is None:
1222-
plotstyle = mdsplotutils.plotstyle
1201+
plotstyle = Mdsplotutils.plotstyle
12231202

12241203
pens = numpy.array(
1225-
[mdsplotutils.make_pen(QColor(*rgba), width=1)
1204+
[Mdsplotutils.make_pen(QColor(*rgba), width=1)
12261205
for rgba in basecolors],
12271206
dtype=object)
12281207

12291208
if flags is None:
12301209
return pens
12311210

1232-
selected_mask = flags & mdsplotutils.Selected
1211+
selected_mask = flags & Mdsplotutils.Selected
12331212
if numpy.any(selected_mask):
12341213
pens[selected_mask.astype(bool)] = plotstyle.selected_pen
12351214

1236-
highlight_mask = flags & mdsplotutils.Highlight
1215+
highlight_mask = flags & Mdsplotutils.Highlight
12371216
if numpy.any(highlight_mask):
12381217
pens[highlight_mask.astype(bool)] = plotstyle.hightlight_pen
12391218

@@ -1242,17 +1221,17 @@ def pen_data(basecolors, flags=None, plotstyle=None):
12421221
@staticmethod
12431222
def brush_data(basecolors, flags=None, plotstyle=None):
12441223
if plotstyle is None:
1245-
plotstyle = mdsplotutils.plotstyle
1224+
plotstyle = Mdsplotutils.plotstyle
12461225

12471226
brush = numpy.array(
1248-
[mdsplotutils.make_brush(QColor(*c))
1227+
[Mdsplotutils.make_brush(QColor(*c))
12491228
for c in basecolors],
12501229
dtype=object)
12511230

12521231
if flags is None:
12531232
return brush
12541233

1255-
fill_mask = flags & mdsplotutils.Filled
1234+
fill_mask = flags & Mdsplotutils.Filled
12561235

12571236
if not numpy.all(fill_mask):
12581237
brush[~fill_mask] = QBrush(Qt.NoBrush)
@@ -1261,7 +1240,7 @@ def brush_data(basecolors, flags=None, plotstyle=None):
12611240
@staticmethod
12621241
def shape_data(table, var, mask=None, plotstyle=None):
12631242
if plotstyle is None:
1264-
plotstyle = mdsplotutils.plotstyle
1243+
plotstyle = Mdsplotutils.plotstyle
12651244

12661245
N = len(table)
12671246
if mask is not None:
@@ -1271,7 +1250,7 @@ def shape_data(table, var, mask=None, plotstyle=None):
12711250
if var is None:
12721251
return numpy.full(N, "o", dtype=object)
12731252
elif var.is_discrete:
1274-
shape_data = mdsplotutils.column_data(table, var, mask)
1253+
shape_data = Mdsplotutils.column_data(table, var, mask)
12751254
maxsymbols = len(plotstyle.symbols) - 1
12761255
validmask = numpy.isfinite(shape_data)
12771256
shape = shape_data % (maxsymbols - 1)
@@ -1289,7 +1268,7 @@ def shape_data(table, var, mask=None, plotstyle=None):
12891268
@staticmethod
12901269
def size_data(table, var, mask=None, plotstyle=None):
12911270
if plotstyle is None:
1292-
plotstyle = mdsplotutils.plotstyle
1271+
plotstyle = Mdsplotutils.plotstyle
12931272

12941273
N = len(table)
12951274
if mask is not None:
@@ -1299,8 +1278,8 @@ def size_data(table, var, mask=None, plotstyle=None):
12991278
if var is None:
13001279
return numpy.full(N, plotstyle.point_size, dtype=float)
13011280
else:
1302-
size_data = mdsplotutils.column_data(table, var, mask)
1303-
size_data = mdsplotutils.normalized(size_data)
1281+
size_data = Mdsplotutils.column_data(table, var, mask)
1282+
size_data = Mdsplotutils.normalized(size_data)
13041283
size_mask = numpy.isnan(size_data)
13051284
size_data = size_data * plotstyle.point_size + \
13061285
plotstyle.min_point_size
@@ -1314,7 +1293,7 @@ def size_data(table, var, mask=None, plotstyle=None):
13141293
@staticmethod
13151294
def legend_data(color_var=None, shape_var=None, plotstyle=None):
13161295
if plotstyle is None:
1317-
plotstyle = mdsplotutils.plotstyle
1296+
plotstyle = Mdsplotutils.plotstyle
13181297

13191298
if color_var is not None and not color_var.is_discrete:
13201299
color_var = None
@@ -1359,7 +1338,9 @@ def make_brush(color, ):
13591338
return QBrush(color, )
13601339

13611340

1362-
def main_test(argv=sys.argv):
1341+
def main_test(argv=None):
1342+
if argv is None:
1343+
argv = sys.argv
13631344
import gc
13641345
app = QApplication(list(argv))
13651346
argv = app.arguments()

Orange/widgets/unsupervised/tests/test_owmds.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
import numpy as np
77

8-
from AnyQt.QtCore import QEvent
9-
108
from Orange.distance import Euclidean
119
from Orange.widgets.unsupervised.owmds import OWMDS
1210
from Orange.widgets.tests.base import WidgetTest, WidgetOutputsTestMixin, datasets
@@ -88,3 +86,20 @@ def test_other_error(self):
8886
hook.assert_not_called()
8987
self.assertTrue(self.widget.Error.optimization_error.is_shown())
9088

89+
def test_distances_without_data_0(self):
90+
"""
91+
Only distances and no data.
92+
GH-2335
93+
"""
94+
signal_data = Euclidean(self.data, axis=0)
95+
signal_data.row_items = None
96+
self.send_signal("Distances", signal_data)
97+
98+
def test_distances_without_data_1(self):
99+
"""
100+
Only distances and no data.
101+
GH-2335
102+
"""
103+
signal_data = Euclidean(self.data, axis=1)
104+
signal_data.row_items = None
105+
self.send_signal("Distances", signal_data)

0 commit comments

Comments
 (0)