Skip to content

Commit b7545b3

Browse files
lanzagar authored and nikicc committed
Merge pull request biolab#2699 from jerneju/spg-scaling-metas
[ENH] Scatter Plot Graph and Scaling can handle metas
1 parent a25ad6a commit b7545b3

File tree

5 files changed, with 133 additions and 136 deletions.

Orange/widgets/unsupervised/owmds.py

Lines changed: 8 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -55,11 +55,6 @@ def update_data(self, attr_x, attr_y, reset_view=True):
5555
self.plot_widget.hideAxis(axis)
5656
self.plot_widget.setAspectLocked(True, 1)
5757

58-
def get_size_index(self):
59-
if self.attr_size == "Stress":
60-
return -2
61-
return super().get_size_index()
62-
6358
def compute_sizes(self):
6459
def scale(a):
6560
dmin, dmax = np.nanmin(a), np.nanmax(a)
@@ -69,17 +64,16 @@ def scale(a):
6964
return np.zeros_like(a)
7065

7166
self.master.Information.missing_size.clear()
72-
size_index = self.get_size_index()
73-
if size_index == -1:
67+
if self.attr_size is None:
7468
size_data = np.full((self.n_points,), self.point_width,
7569
dtype=float)
76-
elif size_index == -2:
70+
elif self.attr_size == "Stress":
7771
size_data = scale(stress(self.master.embedding, self.master.effective_matrix))
7872
size_data = self.MinShapeSize + size_data * self.point_width
7973
else:
8074
size_data = \
8175
self.MinShapeSize + \
82-
self.scaled_data[size_index, self.valid_data] * \
76+
self.scaled_data.get_column_view(self.attr_size)[0][self.valid_data] * \
8377
self.point_width
8478
nans = np.isnan(size_data)
8579
if np.any(nans):
@@ -270,11 +264,6 @@ def update_regression_line(self):
270264

271265
def init_attr_values(self):
272266
domain = self.data and len(self.data) and self.data.domain or None
273-
if domain is not None:
274-
domain = Domain(
275-
attributes=domain.attributes,
276-
class_vars=domain.class_vars,
277-
metas=tuple(a for a in domain.metas if a.is_primitive()))
278267
for model in self.models:
279268
model.set_domain(domain)
280269
self.graph.attr_color = self.data.domain.class_var if domain else None
@@ -653,21 +642,12 @@ def _setup_plot(self, new=False):
653642
coords = np.vstack((emb_x, emb_y)).T
654643

655644
data = self.data
656-
657-
primitive_metas = tuple(a for a in data.domain.metas if a.is_primitive())
658-
keys = [k for k, a in enumerate(data.domain.metas) if a.is_primitive()]
659-
data_metas = data.metas[:, keys].astype(float)
660-
661-
attributes = self.data.domain.attributes + (self.variable_x, self.variable_y) + \
662-
primitive_metas
645+
attributes = data.domain.attributes + (self.variable_x, self.variable_y)
663646
domain = Domain(attributes=attributes,
664-
class_vars=self.data.domain.class_vars)
665-
if data_metas is not None:
666-
data_x = (self.data.X, coords, data_metas)
667-
else:
668-
data_x = (self.data.X, coords)
669-
data = Table.from_numpy(domain, X=hstack(data_x),
670-
Y=self.data.Y)
647+
class_vars=data.domain.class_vars,
648+
metas=data.domain.metas)
649+
data = Table.from_numpy(domain, X=hstack((data.X, coords)),
650+
Y=data.Y, metas=data.metas)
671651
subset_data = data[self._subset_mask] if self._subset_mask is not None else None
672652
self.graph.new_data(data, subset_data=subset_data, new=new)
673653
self.graph.update_data(self.variable_x, self.variable_y, True)

Orange/widgets/utils/scaling.py

Lines changed: 70 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from itertools import chain
2+
13
import numpy as np
24

5+
from Orange.data import Domain
36
from Orange.statistics.basic_stats import DomainBasicStats
47
from Orange.widgets.settings import Setting
58
from Orange.widgets.utils import checksum
@@ -12,12 +15,11 @@ class ScaleData:
1215

1316
def _reset_data(self):
1417
self.domain = None
15-
self.data = None
16-
self.original_data = None # as numpy array
18+
self.data = None # as Orange Table
1719
self.scaled_data = None # in [0, 1]
1820
self.jittered_data = None
1921
self.attr_values = {}
20-
self.domain_data_stat = []
22+
self.domain_data_stat = {}
2123
self.valid_data_array = None
2224
self.attribute_flip_info = {} # dictionary with attr: 0/1 if flipped
2325
self.jitter_seed = 0
@@ -30,53 +32,58 @@ def rescale_data(self):
3032

3133
def _compute_domain_data_stat(self):
3234
stt = self.domain_data_stat = \
33-
getCached(self.data, DomainBasicStats, (self.data,))
34-
for index in range(len(self.domain)):
35-
attr = self.domain[index]
35+
getCached(self.data, DomainBasicStats, (self.data, True))
36+
domain = self.domain
37+
for attr in chain(domain.variables, domain.metas):
3638
if attr.is_discrete:
3739
self.attr_values[attr] = [0, len(attr.values)]
3840
elif attr.is_continuous:
39-
self.attr_values[attr] = [stt[index].min, stt[index].max]
41+
self.attr_values[attr] = [stt[attr].min, stt[attr].max]
4042

4143
def _compute_scaled_data(self):
4244
data = self.data
4345
# We cache scaled_data and validArray to share them between widgets
4446
cached = getCached(data, "visualizationData")
4547
if cached:
46-
self.original_data, self.scaled_data, self.valid_data_array = cached
48+
self.data, self.scaled_data, self.valid_data_array = cached
4749
return
4850

4951
Y = data.Y if data.Y.ndim == 2 else np.atleast_2d(data.Y).T
50-
self.original_data = np.hstack((data.X, Y)).T
51-
self.scaled_data = no_jit = self.original_data.copy()
52-
self.valid_data_array = np.isfinite(no_jit)
53-
for index in range(len(data.domain)):
54-
attr = data.domain[index]
52+
if np.any(data.metas):
53+
all_data = (data.X, Y, data.metas)
54+
else:
55+
all_data = (data.X, Y)
56+
all_data = np.hstack(all_data).T
57+
self.scaled_data = self.data.copy()
58+
self.valid_data_array = np.isfinite(all_data)
59+
domain = self.domain
60+
for attr in chain(domain.attributes, domain.class_vars, domain.metas):
61+
c = self.scaled_data.get_column_view(attr)[0]
5562
if attr.is_discrete:
56-
no_jit[index] *= 2
57-
no_jit[index] += 1
58-
no_jit[index] /= 2 * len(attr.values)
63+
c += 0.5
64+
c /= len(attr.values)
5965
else:
60-
dstat = self.domain_data_stat[index]
61-
no_jit[index] -= dstat.min
66+
dstat = self.domain_data_stat[attr]
67+
c -= dstat.min
6268
if dstat.max != dstat.min:
63-
no_jit[index] /= dstat.max - dstat.min
69+
c /= dstat.max - dstat.min
6470
setCached(data, "visualizationData",
65-
(self.original_data, self.scaled_data, self.valid_data_array))
71+
(self.data, self.scaled_data, self.valid_data_array))
6672

6773
def _compute_jittered_data(self):
6874
data = self.data
6975
self.jittered_data = self.scaled_data.copy()
7076
random = np.random.RandomState(seed=self.jitter_seed)
71-
for index, col in enumerate(self.jittered_data):
77+
domain = self.domain
78+
for attr in chain(domain.variables, domain.metas):
7279
# Need to use a different seed for each feature
73-
attr = data.domain[index]
7480
if attr.is_discrete:
7581
off = self.jitter_size / (25 * max(1, len(attr.values)))
7682
elif attr.is_continuous and self.jitter_continuous:
7783
off = self.jitter_size / 25
7884
else:
7985
continue
86+
col = self.jittered_data.get_column_view(attr)[0]
8087
col += random.uniform(-off, off, len(data))
8188
# fix values outside [0, 1]
8289
col = np.absolute(col)
@@ -92,8 +99,13 @@ def set_data(self, data, skip_if_same=False, no_data=False):
9299
if data is None:
93100
return
94101

95-
self.domain = data.domain
96-
self.data = data
102+
domain = data.domain
103+
new_domain = Domain(attributes=domain.attributes,
104+
class_vars=domain.class_vars,
105+
metas=tuple(v for v in domain.metas if v.is_primitive()))
106+
self.data = data.transform(new_domain)
107+
self.data.metas = self.data.metas.astype(float)
108+
self.domain = self.data.domain
97109
self.attribute_flip_info = {}
98110
if not no_data:
99111
self._compute_domain_data_stat()
@@ -103,67 +115,73 @@ def set_data(self, data, skip_if_same=False, no_data=False):
103115
def flip_attribute(self, attr):
104116
if attr.is_discrete:
105117
return 0
106-
index = self.domain.index(attr)
107118
self.attribute_flip_info[attr] = 1 - self.attribute_flip_info.get(attr, 0)
108119
if attr.is_continuous:
109120
self.attr_values[attr] = [-self.attr_values[attr][1],
110121
-self.attr_values[attr][0]]
111-
112-
self.jittered_data[index] = 1 - self.jittered_data[index]
113-
self.scaled_data[index] = 1 - self.scaled_data[index]
122+
col = self.jittered_data.get_column_view(attr)[0]
123+
col *= -1
124+
col += 1
125+
col = self.scaled_data.get_column_view(attr)[0]
126+
col *= -1
127+
col += 1
114128
return 1
115129

116-
def get_valid_list(self, indices):
130+
def get_valid_list(self, attrs):
117131
"""
118132
Get array of 0 and 1 of len = len(self.data). If there is a missing
119133
value at any attribute in indices return 0 for that instance.
120134
"""
121135
if self.valid_data_array is None or len(self.valid_data_array) == 0:
122136
return np.array([], np.bool)
137+
domain = self.domain
138+
indices = []
139+
for index, attr in enumerate(chain(domain.variables, domain.metas)):
140+
if attr in attrs:
141+
indices.append(index)
123142
return np.all(self.valid_data_array[indices], axis=0)
124143

125-
def get_valid_indices(self, indices):
144+
def get_valid_indices(self, attrs):
126145
"""
127146
Get array with numbers that represent the instance indices that have a
128147
valid data value.
129148
"""
130-
valid_list = self.get_valid_list(indices)
149+
valid_list = self.get_valid_list(attrs)
131150
return np.nonzero(valid_list)[0]
132151

133152

134153
class ScaleScatterPlotData(ScaleData):
135-
def get_xy_data_positions(self, xattr, yattr, filter_valid=False,
154+
def get_xy_data_positions(self, attr_x, attr_y, filter_valid=False,
136155
copy=True):
137156
"""
138157
Create x-y projection of attributes in attrlist.
139158
140159
"""
141-
xattr_index = self.domain.index(xattr)
142-
yattr_index = self.domain.index(yattr)
160+
jit = self.jittered_data
143161
if filter_valid is True:
144-
filter_valid = self.get_valid_list([xattr_index, yattr_index])
162+
filter_valid = self.get_valid_list([attr_x, attr_y])
145163
if isinstance(filter_valid, np.ndarray):
146-
xdata = self.jittered_data[xattr_index, filter_valid]
147-
ydata = self.jittered_data[yattr_index, filter_valid]
164+
data_x = jit.get_column_view(attr_x)[0][filter_valid]
165+
data_y = jit.get_column_view(attr_y)[0][filter_valid]
148166
elif copy:
149-
xdata = self.jittered_data[xattr_index].copy()
150-
ydata = self.jittered_data[yattr_index].copy()
167+
data_x = jit.get_column_view(attr_x)[0].copy()
168+
data_y = jit.get_column_view(attr_y)[0].copy()
151169
else:
152-
xdata = self.jittered_data[xattr_index]
153-
ydata = self.jittered_data[yattr_index]
170+
data_x = jit.get_column_view(attr_x)[0]
171+
data_y = jit.get_column_view(attr_y)[0]
154172

155-
if self.domain[xattr_index].is_discrete:
156-
xdata *= len(self.domain[xattr_index].values)
157-
xdata -= 0.5
173+
if attr_x.is_discrete:
174+
data_x *= len(attr_x.values)
175+
data_x -= 0.5
158176
else:
159-
xdata *= self.attr_values[xattr][1] - self.attr_values[xattr][0]
160-
xdata += float(self.attr_values[xattr][0])
161-
if self.domain[yattr_index].is_discrete:
162-
ydata *= len(self.domain[yattr_index].values)
163-
ydata -= 0.5
177+
data_x *= self.attr_values[attr_x][1] - self.attr_values[attr_x][0]
178+
data_x += float(self.attr_values[attr_x][0])
179+
if attr_y.is_discrete:
180+
data_y *= len(attr_y.values)
181+
data_y -= 0.5
164182
else:
165-
ydata *= self.attr_values[yattr][1] - self.attr_values[yattr][0]
166-
ydata += float(self.attr_values[yattr][0])
167-
return xdata, ydata
183+
data_y *= self.attr_values[attr_y][1] - self.attr_values[attr_y][0]
184+
data_y += float(self.attr_values[attr_y][0])
185+
return data_x, data_y
168186

169187
getXYDataPositions = get_xy_data_positions

Orange/widgets/visualize/owscatterplot.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,12 @@ def iterate_states(self, initial_state):
4747

4848
def compute_score(self, state):
4949
graph = self.master.graph
50-
ind12 = [graph.domain.index(self.attrs[x]) for x in state]
51-
valid = graph.get_valid_list(ind12)
52-
X = graph.jittered_data[ind12, :][:, valid].T
50+
attrs = [self.attrs[x] for x in state]
51+
valid = graph.get_valid_list(attrs)
52+
cols = []
53+
for var in attrs:
54+
cols.append(graph.jittered_data.get_column_view(var)[0][valid])
55+
X = np.column_stack(cols)
5356
Y = self.master.data.Y[valid]
5457
if X.shape[0] < self.minK:
5558
return
@@ -66,7 +69,7 @@ def bar_length(self, score):
6669
return max(0, -score)
6770

6871
def score_heuristic(self):
69-
X = self.master.graph.jittered_data.T
72+
X = self.master.graph.jittered_data.X
7073
Y = self.master.data.Y
7174
mdomain = self.master.data.domain
7275
dom = Domain([ContinuousVariable(str(i)) for i in range(X.shape[1])],
@@ -139,7 +142,6 @@ def __init__(self):
139142

140143
self.data = None # Orange.data.Table
141144
self.subset_data = None # Orange.data.Table
142-
self.data_metas_X = None # self.data, where primitive metas are moved to X
143145
self.sql_data = None # Orange.data.sql.table.SqlTable
144146
self.attribute_selection_list = None # list of Orange.data.Variable
145147
self.__timer = QTimer(self, interval=1200)
@@ -243,7 +245,6 @@ def set_data(self, data):
243245
same_domain = (self.data and data and
244246
data.domain.checksum() == self.data.domain.checksum())
245247
self.data = data
246-
self.data_metas_X = self.move_primitive_metas_to_X(data)
247248

248249
if not same_domain:
249250
self.init_attr_values()
@@ -295,7 +296,6 @@ def add_data(self, time=0.4):
295296
data_sample.download_data(2000, partial=True)
296297
data = Table(data_sample)
297298
self.data = Table.concatenate((self.data, data), axis=0)
298-
self.data_metas_X = self.move_primitive_metas_to_X(self.data)
299299
self.handleNewSignals()
300300

301301
def switch_sampling(self):
@@ -304,15 +304,6 @@ def switch_sampling(self):
304304
self.add_data()
305305
self.__timer.start()
306306

307-
def move_primitive_metas_to_X(self, data):
308-
if data is not None:
309-
new_attrs = [a for a in data.domain.attributes + data.domain.metas
310-
if a.is_primitive()]
311-
new_metas = [m for m in data.domain.metas if not m.is_primitive()]
312-
new_domain = Domain(new_attrs, data.domain.class_vars, new_metas)
313-
data = data.transform(new_domain)
314-
return data
315-
316307
@Inputs.data_subset
317308
def set_subset_data(self, subset_data):
318309
self.warning()
@@ -322,12 +313,11 @@ def set_subset_data(self, subset_data):
322313
else:
323314
self.warning("Data subset does not support large Sql tables")
324315
subset_data = None
325-
self.subset_data = self.move_primitive_metas_to_X(subset_data)
326316
self.controls.graph.alpha_value.setEnabled(subset_data is None)
327317

328318
# called when all signals are received, so the graph is updated only once
329319
def handleNewSignals(self):
330-
self.graph.new_data(self.data_metas_X, self.subset_data)
320+
self.graph.new_data(self.data, self.subset_data)
331321
if self.attribute_selection_list and self.graph.domain and \
332322
all(attr in self.graph.domain
333323
for attr in self.attribute_selection_list):

0 commit comments

Comments
 (0)