Skip to content

Commit 9a15d0d

Browse files
Add tests
1 parent 53a5638 commit 9a15d0d

File tree

2 files changed

+160
-58
lines changed

2 files changed

+160
-58
lines changed

Orange/widgets/visualize/owvenndiagram.py

Lines changed: 53 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class Outputs:
5959
class Error(widget.OWWidget.Error):
6060
domain_mismatch = Msg("Input data domains do not match.")
6161
instances_mismatch = Msg("Data sets do not contain the same instances.")
62+
too_much_inputs = Msg("venn diagram accepts at most five datasets.")
6263

6364
selection: list
6465

@@ -93,7 +94,7 @@ def __init__(self):
9394
self.infolabel = gui.widgetLabel(box, "No data on input.\n")
9495

9596
self.elementsBox = gui.radioButtonsInBox(
96-
self.controlArea, self, 'usecols', [],
97+
self.controlArea, self, 'usecols', [],
9798
box="Elements", callback=self._on_elements_changed
9899
)
99100

@@ -169,10 +170,9 @@ def __init__(self):
169170
@Inputs.data
170171
@check_sql_input
171172
def setData(self, data, key=None):
172-
self.error()
173+
self.Error.too_much_inputs.clear()
173174
if not self._inputUpdate:
174175
self._inputUpdate = True
175-
176176
if key in self.data:
177177
if data is None:
178178
# Remove the input
@@ -187,7 +187,7 @@ def setData(self, data, key=None):
187187
# TODO: Allow setting more them 5 inputs and let the user
188188
# select the 5 to display.
189189
if len(self.data) == 5:
190-
self.error("Venn diagram accepts at most five datasets.")
190+
self.Error.too_much_inputs()
191191
return
192192
# Add a new input
193193
self.data[key] = _InputData(key, data.name, data)
@@ -253,7 +253,7 @@ def handleNewSignals(self):
253253
self._updateInfo()
254254
super().handleNewSignals()
255255

256-
def itemsetAttr(self, key):
256+
def itemsetAttr(self):
257257
model = self.cb.model()
258258
attr_index = self.cb.currentIndex()
259259
if attr_index >= 0:
@@ -267,23 +267,18 @@ def _controlAtIndex(self, index):
267267
return group_box, combo
268268

269269
def intersectionStringAttrs(self):
270-
atrs = None
271-
for data_ in self.data.values():
272-
if atrs == None:
273-
atrs = set(string_attributes(data_.table.domain))
274-
else:
275-
atrs = atrs.intersection(set(string_attributes(data_.table.domain)))
276-
if atrs is not None:
277-
return list(atrs)
278-
return atrs
270+
sets = [set(string_attributes(data_.table.domain)) for data_ in self.data.values()]
271+
if sets:
272+
return reduce(set.intersection, sets)
273+
return set()
279274

280275
def _setInterAttributes(self):
281276
#finds intersection of string attributes for identifiers checkbox
282277
box = self.inputsBox.layout().itemAt(0).widget()
283278
combo = box.combo_box
284279
model = combo.model()
285280
atrs = self.intersectionStringAttrs()
286-
if (atrs is None) or (len(atrs) == 0):
281+
if not atrs:
287282
self.useidentifiers = False
288283
model[:] = []
289284
box.setEnabled(False)
@@ -297,24 +292,24 @@ def _itemsForInput(self, key):
297292
"""
298293
useidentifiers = self.useidentifiers or not self.samedomain
299294

300-
def items_by_key(key, input):
295+
def items_by_key(table):
301296
model = self.cb.model()
302297
attr_index = self.cb.currentIndex()
303298
if attr_index >= 0:
304299
attr = model[attr_index]
305-
return [str(inst[attr]) for inst in input.table
300+
return [str(inst[attr]) for inst in table
306301
if not numpy.isnan(inst[attr])]
307302
else:
308303
return []
309-
def items_by_eq(key, input):
310304

311-
return list(map(ComparableInstance, input.table))
305+
def items_by_eq(table):
306+
return list(map(ComparableInstance, table))
312307

313-
input = self.data[key]
308+
table = self.data[key].table
314309
if useidentifiers:
315-
items = items_by_key(key, input)
310+
items = items_by_key(table)
316311
else:
317-
items = items_by_eq(key, input)
312+
items = items_by_eq(table)
318313
return items
319314

320315
def _updateItemsets(self):
@@ -335,12 +330,12 @@ def _createItemsets(self):
335330
olditemsets = dict(self.itemsets)
336331
self.itemsets.clear()
337332

338-
for key, input in self.data.items():
333+
for key, input_ in self.data.items():
339334
if self.usecols:
340-
items = [el.name for el in input.table.domain.attributes]
335+
items = [el.name for el in input_.table.domain.attributes]
341336
else:
342337
items = self._itemsForInput(key)
343-
name = input.name
338+
name = input_.name
344339
if key in olditemsets and olditemsets[key].name == name:
345340
# Reuse the title (which might have been changed by the user)
346341
title = olditemsets[key].title
@@ -412,7 +407,7 @@ def _updateInfo(self):
412407
# Clear all warnings
413408
self.warning()
414409

415-
if not len(self.data):
410+
if not self.data:
416411
self.infolabel.setText("No data on input\n")
417412
else:
418413
self.infolabel.setText(f"{len(self.data)} datasets on input\n")
@@ -472,25 +467,6 @@ def _on_itemTextEdited(self, index, text):
472467
def invalidateOutput(self):
473468
self.commit()
474469

475-
def arrays_equal(self, a, b, type_):
476-
"""
477-
checks if arrays have nans in same places and if not-nan elements
478-
are equal
479-
"""
480-
if a is None and b is None:
481-
return True
482-
if a is None or b is None:
483-
return False
484-
if type_ is not StringVariable:
485-
if not numpy.all(numpy.argwhere(numpy.isnan(a)) == numpy.argwhere(numpy.isnan(b))):
486-
return False
487-
if not numpy.any(a[numpy.logical_not(numpy.isnan(a))] == b[numpy.logical_not(numpy.isnan(b))]):
488-
return False
489-
return True
490-
else:
491-
if not(a == b).all():
492-
return False
493-
return True
494470

495471
def extract_new_table(self, var_dict):
496472
domain = defaultdict(lambda: [])
@@ -501,7 +477,7 @@ def extract_new_table(self, var_dict):
501477
if var_data[0]:
502478
#columns are different, copy all, rename them
503479
for var, table_key in var_data[1]:
504-
domain[atr_type].append(var.make('{}_{}{}'.format(var_name, self.data[table_key].table.name, table_key[0])))
480+
domain[atr_type].append(var.make('{}-{}{}'.format(var_name, self.data[table_key].table.name, table_key[0])))
505481
values[atr_type].append(getattr(self.data[table_key].table[:, var_name], atr_vals[atr_type]).reshape(-1, 1))
506482
else:
507483
domain[atr_type].append(deepcopy(var_data[1][0][0]))
@@ -518,10 +494,8 @@ def extract_new_table(self, var_dict):
518494

519495
def create_from_columns(self, columns):
520496
"""
521-
If venn diagram is over columns, columns are selected from first dataset that has them.
522-
Annotated data retains all columns from all datasets and adds an attribute to features,
523-
indicating wether it was selected. Columns are duplicated only if values differ (even
524-
if only in order of values), origin table name is added to column name.
497+
Columns are duplicated only if values differ (even
498+
if only in order of values), origin table name and input slot is added to column name.
525499
"""
526500
selected = None
527501
atr_vals = {'metas': 'metas', 'attributes': 'X', 'class_vars': 'Y'}
@@ -536,9 +510,11 @@ def merge_vars(new_atrs, atr):
536510
if atr.name in new_atrs.keys():
537511
if not new_atrs[atr.name][0]:
538512
for var, key in new_atrs[atr.name][1]:
539-
if not self.arrays_equal(
540-
getattr(self.data[table_key].table[:, atr.name], atr_vals[atr_type]),
541-
getattr(self.data[key].table[:, atr.name], atr_vals[atr_type]),
513+
if not arrays_equal(
514+
getattr(self.data[table_key].table[:, atr.name],
515+
atr_vals[atr_type]),
516+
getattr(self.data[key].table[:, atr.name],
517+
atr_vals[atr_type]),
542518
type(var)):
543519
new_atrs[atr.name][0] = True
544520
break
@@ -569,8 +545,9 @@ def merge_vars(new_atrs, atr):
569545
container = reduce(merge_vars, atrs, container)
570546
var_dict[atr_type] = container
571547
annotated = self.extract_new_table(var_dict)
548+
572549
for atr in annotated.domain.attributes:
573-
atr.attributes['Selected'] = atr in selected.domain.attributes
550+
atr.attributes['Selected'] = selected and atr in selected.domain.attributes
574551

575552
return selected, annotated
576553

@@ -609,13 +586,13 @@ def match(val):
609586
names = [itemset.title.strip() for itemset in self.itemsets.values()]
610587
names = uniquify(names)
611588

612-
for i, (key, input) in enumerate(self.data.items()):
589+
for i, input in enumerate(self.data.values()):
613590
# cell vars are in functions that are only used in the loop
614591
# pylint: disable=cell-var-from-loop
615592
if not len(input.table):
616593
continue
617594
if self.useidentifiers:
618-
attr = self.itemsetAttr(key)
595+
attr = self.itemsetAttr()
619596
if attr is not None:
620597
mask = list(map(match, (inst[attr] for inst in input.table)))
621598
else:
@@ -655,7 +632,7 @@ def instance_key_all(inst):
655632
annotated_data_subsets.append(annotated_subset)
656633
annotated_data_masks.append(mask)
657634

658-
if len(subset) == 0:
635+
if not subset:
659636
continue
660637

661638
# add columns with source table id and set id
@@ -1719,6 +1696,25 @@ def group_table_indices(table, key_var):
17191696
groups[str(inst[key_var])].append(i)
17201697
return groups
17211698

1699+
def arrays_equal(a, b, type_):
1700+
"""
1701+
checks if arrays have nans in same places and if not-nan elements
1702+
are equal
1703+
"""
1704+
if a is None and b is None:
1705+
return True
1706+
if a is None or b is None:
1707+
return False
1708+
if type_ is not StringVariable:
1709+
if not numpy.all(numpy.isnan(a) == numpy.isnan(b)):
1710+
return False
1711+
if not numpy.any(a[numpy.logical_not(numpy.isnan(a))] == b[numpy.logical_not(numpy.isnan(b))]):
1712+
return False
1713+
return True
1714+
else:
1715+
if not(a == b).all():
1716+
return False
1717+
return True
17221718

17231719
if __name__ == "__main__": # pragma: no cover
17241720
from Orange.evaluation import ShuffleSplit

0 commit comments

Comments
 (0)