Skip to content

Commit 97a4856

Browse files
Add output_duplicates function
1 parent 8784364 commit 97a4856

File tree

2 files changed

+106
-98
lines changed

2 files changed

+106
-98
lines changed

Orange/widgets/visualize/owvenndiagram.py

Lines changed: 105 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ class Error(widget.OWWidget.Error):
5959
instances_mismatch = Msg("Data sets do not contain the same instances.")
6060
too_many_inputs = Msg("Venn diagram accepts at most five datasets.")
6161

62+
class Warning(widget.OWWidget.Warning):
63+
renamed_vars = Msg("Some variables have been renamed "
64+
"to avoid duplicates.\n{}")
65+
6266
selection: list
6367

6468
settingsHandler = settings.DomainContextHandler()
@@ -73,6 +77,8 @@ class Error(widget.OWWidget.Error):
7377

7478
want_control_area = False
7579
graph_name = "scene"
80+
atr_types = ['attributes', 'metas', 'class_vars']
81+
atr_vals = {'metas': 'metas', 'attributes': 'X', 'class_vars': 'Y'}
7682

7783
def __init__(self):
7884
super().__init__()
@@ -347,18 +353,21 @@ def invalidateOutput(self):
347353

348354
def merge_data(self, domain, values):
349355
X, metas, class_vars = None, None, None
356+
renamed = []
350357
for val in domain.values():
351358
names = [var.name for var in val]
352359
unique_names = get_unique_names_duplicates(names)
353-
for n, u, var in zip(names, unique_names, val):
360+
for n, u, idx, var in zip(names, unique_names, range(len(val)), val):
354361
if n != u:
355-
var.name = u
356-
#TODO: warning because of a weird clash?
357-
if values['attributes']:
362+
val[idx] = var.copy(name=u)
363+
renamed.append(n)
364+
if renamed:
365+
self.Warning.renamed_vars(', '.join(renamed))
366+
if 'attributes' in values.keys():
358367
X = np.hstack(values['attributes'])
359-
if values['metas']:
368+
if 'metas' in values.keys():
360369
metas = np.hstack(values['metas'])
361-
if values['class_vars']:
370+
if 'class_vars' in values.keys():
362371
class_vars = np.hstack(values['class_vars'])
363372
return Table.from_numpy(Domain(**domain), X, class_vars, metas)
364373

@@ -380,7 +389,7 @@ def extract_new_table(self, var_dict):
380389
values[atr_type].append(getattr(self.data[var_data[1][0][1]].table[:, var_name], atr_vals[atr_type]).reshape(-1, 1))
381390
return self.merge_data(domain, values)
382391

383-
def curry_merge(self, table_key, atr_type, ids=None):
392+
def curry_merge(self, table_key, atr_type, ids=None, selection=False):
384393
if self.rowwise:
385394
check_equality = self.arrays_equal_rows
386395
else:
@@ -389,23 +398,27 @@ def curry_merge(self, table_key, atr_type, ids=None):
389398
def inner(new_atrs, atr):
390399
"""
391400
Atrs - list of variables we wish to merge
392-
new_atrs - dictionary where key is old name, val
393-
is [is_different:bool, table_keys:list])
401+
new_atrs - dictionary where key is old var, val
402+
is [is_different:bool, table_keys:list]), is_different is set to True,
403+
if we are outputing duplicates, but the value is arbitrary
394404
"""
395405
atr_vals = {'metas': 'metas', 'attributes': 'X', 'class_vars': 'Y'}
396-
if atr.name in new_atrs.keys():
397-
if not new_atrs[atr.name][0]:
398-
for var, key in new_atrs[atr.name][1]:
406+
if atr in new_atrs.keys():
407+
if not selection and self.output_duplicates:
408+
#if output_duplicates, we just check if compute value is the same
409+
new_atrs[atr][0] = True
410+
elif not new_atrs[atr][0]:
411+
for var, key in new_atrs[atr][1]:
399412
if not check_equality(table_key,
400413
key,
401414
atr.name,
402415
atr_vals[atr_type],
403416
type(var), ids):
404-
new_atrs[atr.name][0] = True
417+
new_atrs[atr][0] = True
405418
break
406-
new_atrs[atr.name][1].append((atr, table_key))
419+
new_atrs[atr][1].append((atr, table_key))
407420
else:
408-
new_atrs[atr.name] = [False, [(atr, table_key)]]
421+
new_atrs[atr] = [False, [(atr, table_key)]]
409422
return new_atrs
410423
return inner
411424

@@ -468,7 +481,7 @@ def extract_rowwise(self, var_dict, ids=None, selection=False):
468481
is [is_different:bool, table_keys:list])
469482
ids: dict with ids for each table
470483
"""
471-
all_ids = list(reduce(set.union, [set(val.keys()) for val in ids.values()], set()))
484+
all_ids = sorted(list(reduce(set.union, [set(val.keys()) for val in ids.values()], set())))
472485

473486
permutations = dict()
474487
for table_key, dict_ in ids.items():
@@ -479,8 +492,8 @@ def extract_rowwise(self, var_dict, ids=None, selection=False):
479492
atr_vals = {'metas': 'metas', 'attributes': 'X', 'class_vars': 'Y'}
480493
for atr_type, vars_dict in var_dict.items():
481494
for var_name, var_data in vars_dict.items():
482-
duplicated = var_data[0]
483-
if duplicated:
495+
different = var_data[0]
496+
if different:
484497
#columns are different, copy all, rename them
485498
for var, table_key in var_data[1]:
486499
temp = self.data[table_key].table
@@ -517,41 +530,92 @@ def extract_rowwise(self, var_dict, ids=None, selection=False):
517530

518531
def get_indices(self, table, selection):
519532
"""Returns mappings of ids (be it row id or string) to indices in tables"""
520-
#TODO: refactor?
521533
if self.selected_feature:
522-
items, ids = np.unique(getattr(table[:, self.selected_feature], 'metas'),
523-
return_index=True)
524-
if selection:
525-
return OrderedDict([(item, idx) for item, idx in zip(items, ids)
526-
if item in self.selected_items])
527-
return OrderedDict(zip(items, ids))
534+
if self.output_duplicates and selection:
535+
items, inverse = np.unique(getattr(table[:, self.selected_feature], 'metas'),
536+
return_inverse=True)
537+
ids = [np.nonzero(inverse == idx)[0] for idx in range(len(items))]
538+
else:
539+
items, ids = np.unique(getattr(table[:, self.selected_feature], 'metas'),
540+
return_index=True)
541+
542+
else:
543+
items = table.ids
544+
ids = range(len(table))
545+
528546
if selection:
529-
if not self.selected_items:
530-
return None
531-
return OrderedDict([(idx, val) for val, idx in zip(range(len(table.ids)), table.ids)
532-
if idx in self.selected_items])
533-
return OrderedDict(zip(table.ids, range(len(table))))
534-
535-
def get_indices_to_match_by(self, selected_keys):
536-
selected, annotated = dict(), dict()
537-
for key, val in self.data.items():
538-
annotated[key] = self.get_indices(val.table, None)
539-
if self.selection and key in selected_keys:
540-
selected[key] = self.get_indices(val.table, self.selection)
541-
return selected, annotated
547+
return OrderedDict([(item, idx) for item, idx in zip(items, ids)
548+
if item in self.selected_items])
549+
550+
return OrderedDict(zip(items, ids))
551+
552+
def get_indices_to_match_by(self, relevant_keys, selection=False):
553+
dict_ = dict()
554+
for key in relevant_keys:
555+
table = self.data[key].table
556+
dict_[key] = self.get_indices(table, selection)
557+
return dict_
542558

543559
def create_from_rows(self, relevant_keys, relevant_ids, selection=False):
544560
atr_types = ['attributes', 'metas', 'class_vars']
545561
var_dict = {}
546562
for atr_type in atr_types:
547563
container = {}
548564
for table_key in relevant_keys:
549-
merge_vars = self.curry_merge(table_key, atr_type, relevant_ids)
565+
merge_vars = self.curry_merge(table_key, atr_type, relevant_ids, selection)
550566
atrs = getattr(self.data[table_key].table.domain, atr_type)
551567
container = reduce(merge_vars, atrs, container)
552568
var_dict[atr_type] = container
569+
if self.output_duplicates and not selection:
570+
return self.extract_rowwise_duplicates(var_dict, relevant_ids, relevant_keys)
553571
return self.extract_rowwise(var_dict, relevant_ids, selection)
554572

573+
def make_it_fit(self, a, b):
574+
#TODO: rename function
575+
if a == b.shape:
576+
return b
577+
if a[1] == 1:
578+
return np.atleast_2d(b).T
579+
return np.atleast_2d(b)
580+
581+
def expand_tables(self, table, atrs, metas, cv):
582+
exp = []
583+
for all_el, atr_type in zip([atrs, metas, cv], self.atr_types):
584+
#TODO : pohendlaj manjakoče atr_type & columns
585+
cur_el = getattr(table.domain, atr_type)
586+
perm = get_perm(cur_el, all_el)
587+
array = np.empty((len(table), len(all_el)))
588+
array.fill(np.nan)
589+
b = getattr(table, self.atr_vals[atr_type])
590+
array[:, perm] = self.make_it_fit(array[:, perm].shape, b)
591+
#array[:, perm] = np.atleast_2d(getattr(table, self.atr_vals[atr_type]))
592+
exp.append(array)
593+
#TODO: maybe this could be smarter
594+
return exp[0], exp[1], exp[2]
595+
596+
def extract_rowwise_duplicates(self, var_dict, ids, relevant_keys):
597+
#za vsak id v vsakemu stolpcu rabimo indekse
598+
#extractamo celo podtabelo, vstavimo morebitne manjkajoče stolpce, na koncu vstack
599+
all_ids = sorted(list(reduce(set.union, [set(val.keys()) for val in ids.values()], set())))
600+
sort_key = lambda var: var.name
601+
all_atrs = sorted([var for var in var_dict['attributes'].keys()], key=sort_key)
602+
all_metas = sorted([var for var in var_dict['metas'].keys()], key=sort_key)
603+
all_cv = sorted([var for var in var_dict['class_vars'].keys()], key=sort_key)
604+
605+
all_x, all_y, all_m = [], [], []
606+
for idx in all_ids:
607+
#iterate trough tables with same idx
608+
for table_key in relevant_keys:
609+
map_ = ids[table_key][idx]
610+
extracted = self.data[table_key].table[map_]
611+
x, m, y = self.expand_tables(extracted, all_atrs, all_metas, all_cv)
612+
all_x.append(x)
613+
all_y.append(y)
614+
all_m.append(m)
615+
domain = {'attributes': all_atrs, 'metas': all_metas, 'class_vars': all_cv}
616+
values = {'attributes': [np.vstack(all_x)], 'metas': [np.vstack(all_m)], 'class_vars': [np.vstack(all_y)]}
617+
return self.merge_data(domain, values)
618+
555619
def commit(self):
556620

557621
if not self.vennwidget.vennareas() or not self.data.keys():
@@ -569,7 +633,8 @@ def commit(self):
569633
selected = None
570634

571635
if self.rowwise:
572-
selected_ids, annotated_ids = self.get_indices_to_match_by(selected_keys)
636+
selected_ids = self.get_indices_to_match_by(selected_keys, self.selection)
637+
annotated_ids = self.get_indices_to_match_by(self.data.keys())
573638
annotated = self.create_from_rows(self.data.keys(), annotated_ids, True)
574639
if self.selected_items:
575640
selected = self.create_from_rows(selected_keys, selected_ids, False)

Orange/widgets/visualize/tests/test_owvenndiagram.py

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -19,63 +19,6 @@
1919
get_perm)
2020
from Orange.tests import test_filename
2121

22-
23-
class TestVennDiagram(unittest.TestCase):
24-
def add_metas(self, table, meta_attrs, meta_data):
25-
domain = Domain(table.domain.attributes,
26-
table.domain.class_vars,
27-
table.domain.metas + meta_attrs)
28-
metas = np.hstack((table.metas, meta_data))
29-
return Table(domain, table.X, table.Y, metas)
30-
31-
"""
32-
def test_venn_diagram(self):
33-
sources = ["SVM Learner", "Naive Bayes", "Random Forest"]
34-
item_id_var = StringVariable("item_id")
35-
source_var = StringVariable("source")
36-
table = Table("zoo")
37-
class_var = table.domain.class_var
38-
cv = np.random.randint(len(class_var.values), size=(3, len(sources)))
39-
40-
tables = []
41-
# pylint: disable=consider-using-enumerate
42-
for i in range(len(sources)):
43-
temp_table = Table.from_table(table.domain, table,
44-
[0 + i, 1 + i, 2 + i])
45-
temp_d = (DiscreteVariable("%s(%s)" % (class_var.name,
46-
sources[0 + i]),
47-
class_var.values),
48-
source_var, item_id_var)
49-
temp_m = np.array([[cv[0, i], sources[i], table.metas[0 + i, 0]],
50-
[cv[1, i], sources[i], table.metas[1 + i, 0]],
51-
[cv[2, i], sources[i], table.metas[2 + i, 0]]],
52-
dtype=object)
53-
temp_table = self.add_metas(temp_table, temp_d, temp_m)
54-
tables.append(temp_table)
55-
56-
data = table_concat(tables)
57-
varying = varying_between(data, item_id_var)
58-
if source_var in varying:
59-
varying.remove(source_var)
60-
data = reshape_wide(data, varying, [item_id_var], [source_var])
61-
data = drop_columns(data, [item_id_var])
62-
63-
result = np.array([[table.metas[0, 0], cv[0, 0], np.nan, np.nan],
64-
[table.metas[1, 0], cv[1, 0], cv[0, 1], np.nan],
65-
[table.metas[2, 0], cv[2, 0], cv[1, 1], cv[0, 2]],
66-
[table.metas[3, 0], np.nan, cv[2, 1], cv[1, 2]],
67-
[table.metas[4, 0], np.nan, np.nan, cv[2, 2]]],
68-
dtype=object)
69-
70-
for i in range(len(result)):
71-
for j in range(len(result[0])):
72-
val = result[i][j]
73-
if isinstance(val, float) and np.isnan(val):
74-
self.assertTrue(np.isnan(data.metas[i][j]))
75-
else:
76-
np.testing.assert_equal(data.metas[i][j], result[i][j])
77-
78-
"""
7922
class TestOWVennDiagram(WidgetTest, WidgetOutputsTestMixin):
8023
@classmethod
8124
def setUpClass(cls):
@@ -170,7 +113,7 @@ def test_multiple_input_over_cols(self):
170113
true_atrs = {'sepal length (2)', 'sepal length (1)'}
171114
self.assertTrue(atrs == true_atrs)
172115

173-
out_domain = annotated.domain
116+
out_domain = annotated.domain.attributes
174117
self.assertTrue(out_domain[0].attributes[selected_atr_name])
175118
self.assertTrue(out_domain[1].attributes[selected_atr_name])
176119
self.assertFalse(out_domain[2].attributes[selected_atr_name])

0 commit comments

Comments
 (0)