Skip to content

Commit f632649

Browse files
committed
OWConcatenate: Create new variables when they differ, prevent duplicated names
1 parent 60f070e commit f632649

File tree

2 files changed

+256
-47
lines changed

2 files changed

+256
-47
lines changed

Orange/widgets/data/owconcatenate.py

Lines changed: 81 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,17 @@
66
77
"""
88

9-
from collections import OrderedDict
9+
from collections import OrderedDict, namedtuple
1010
from functools import reduce
11+
from itertools import chain, count
12+
from typing import List
1113

1214
import numpy as np
1315
from AnyQt.QtWidgets import QFormLayout
1416
from AnyQt.QtCore import Qt
1517

1618
import Orange.data
19+
from Orange.data.util import get_unique_names_duplicates
1720
from Orange.util import flatten
1821
from Orange.widgets import widget, gui, settings
1922
from Orange.widgets.settings import Setting
@@ -43,6 +46,10 @@ class Outputs:
4346
class Error(widget.OWWidget.Error):
    # Error message group of the widget; the single message below tells the
    # user that the inputs must all be of the same type.
    bow_concatenation = Msg(
        "Inputs must be of the same type.")
4548

49+
class Warning(widget.OWWidget.Warning):
    # Warning message group of the widget; the message below is shown when
    # variables had to be renamed because their names were duplicated.
    renamed_variables = \
        Msg("Variables with duplicated names have been renamed.")
52+
4653
merge_type: int
4754
append_source_column: bool
4855
source_column_role: int
@@ -172,18 +179,15 @@ def incompatible_types(self):
172179
return False
173180

174181
def apply(self):
182+
self.Warning.renamed_variables.clear()
175183
tables, domain, source_var = [], None, None
176184
if self.primary_data is not None:
177185
tables = [self.primary_data] + list(self.more_data.values())
178186
domain = self.primary_data.domain
179187
elif self.more_data:
180188
tables = self.more_data.values()
181-
if self.merge_type == OWConcatenate.MergeUnion:
182-
domain = reduce(domain_union,
183-
(table.domain for table in tables))
184-
else:
185-
domain = reduce(domain_intersection,
186-
(table.domain for table in tables))
189+
domains = [table.domain for table in tables]
190+
domain = self.merge_domains(domains)
187191

188192
if tables and self.append_source_column:
189193
assert domain is not None
@@ -236,36 +240,76 @@ def send_report(self):
236240
self.id_roles[self.source_column_role].lower())
237241
self.report_items(items)
238242

239-
240-
def unique(seq):
    """
    Lazily yield the elements of `seq` in their original order, skipping
    every element that has already been yielded.

    Elements must be hashable because a set tracks what was seen.
    """
    already_yielded = set()
    for item in seq:
        if item in already_yielded:
            continue
        yield item
        already_yielded.add(item)
246-
247-
248-
def domain_union(a, b):
    """
    Build a domain holding the variables of both `a` and `b`.

    Attributes, class variables and metas are combined separately: the
    tuples from `a` and `b` are concatenated and repeated variables are
    dropped, keeping the first occurrence of each.
    """
    attributes = tuple(unique(a.attributes + b.attributes))
    class_vars = tuple(unique(a.class_vars + b.class_vars))
    metas = tuple(unique(a.metas + b.metas))
    return Orange.data.Domain(attributes, class_vars, metas)
255-
256-
257-
def domain_intersection(a, b):
    """
    Build a domain holding only the variables present in both `a` and `b`.

    Attributes, class variables and metas are intersected separately; within
    each part the variables keep the order in which they first appear in the
    concatenation of the tuples from `a` and `b`.
    """
    def shared(first, second):
        # membership is decided with sets, the order comes from the tuples
        common = set(first).intersection(second)
        return tuple(unique(var for var in first + second if var in common))

    attributes = shared(a.attributes, b.attributes)
    class_vars = shared(a.class_vars, b.class_vars)
    metas = shared(a.metas, b.metas)
    return Orange.data.Domain(attributes, class_vars, metas)
243+
def merge_domains(self, domains):
    """
    Merge `domains` into one domain whose variable names are unique.

    Attributes, class variables and metas are each combined by
    `_get_part`, using a set union when `self.merge_type` equals
    `OWConcatenate.MergeUnion` and a set intersection otherwise. The names
    of all resulting variables (in the order attributes, class variables,
    metas) are passed to `get_unique_names_duplicates`; every variable
    whose name differs from the returned one is replaced by a renamed copy
    and `Warning.renamed_variables` is shown.

    Returns the new `Orange.data.Domain`.
    """
    if self.merge_type == OWConcatenate.MergeUnion:
        combine = set.union
    else:
        combine = set.intersection

    roles = ("attributes", "class_vars", "metas")
    parts = [self._get_part(domains, combine, role) for role in roles]

    proposed = [var.name for part in parts for var in part]
    # One shared iterator: names are handed out to the variables in exactly
    # the order in which `proposed` was built.
    unique_names = iter(get_unique_names_duplicates(proposed))

    for part in parts:
        for idx, var in enumerate(part):
            new_name = next(unique_names)
            if var.name != new_name:
                part[idx] = var.renamed(new_name)
                self.Warning.renamed_variables()

    return Orange.data.Domain(*parts)
260+
261+
@classmethod
def _get_part(cls, domains, oper, part):
    """
    Combine one part of several domains into a list of variables.

    `part` is the name of the domain attribute to read ("attributes",
    "class_vars" or "metas") and `oper` is the set operation that decides
    which variables are kept. The kept variables are listed in the order
    in which they occur across `domains` and then passed to
    `_unique_vars`, whose result is returned.
    """
    per_domain = [getattr(domain, part) for domain in domains]
    # Sets decide membership only; the order is taken from the original
    # sequences below.
    allowed = reduce(oper, (set(variables) for variables in per_domain))
    ordered = []
    for variables in per_domain:
        ordered.extend(var for var in variables if var in allowed)
    return cls._unique_vars(ordered)
269+
270+
@staticmethod
def _unique_vars(seq: List[Orange.data.Variable]):
    """
    Collapse repeated variables in `seq` into one variable each.

    The first occurrence of a variable serves as the template. For a
    discrete template, values of later occurrences that the template lacks
    are collected (in order of appearance); for a continuous template the
    largest `number_of_decimals` among the occurrences is remembered.

    Returns a list with one entry per distinct variable, in order of first
    appearance: the template itself if nothing had to change, otherwise a
    copy with the extra values added or the number of decimals increased.
    """
    AttrDesc = namedtuple(
        "AttrDesc",
        ("template", "original", "values", "number_of_decimals"))

    # First pass: gather, per distinct variable, what the merged variable
    # has to contain.
    descriptors = {}
    for var in seq:
        known = descriptors.get(var)
        if known is None:
            descriptors[var] = AttrDesc(
                template=var,
                original=True,
                values=var.is_discrete and var.values,
                number_of_decimals=(
                    var.is_continuous and var.number_of_decimals))
        elif known.template.is_discrete:
            seen_values = set(known.values)
            # a list, not a set difference, so that the order is kept
            extra = [val for val in var.values if val not in seen_values]
            if extra:
                descriptors[var] = known._replace(
                    original=False,
                    values=known.values + extra)
        elif known.template.is_continuous:
            if var.number_of_decimals > known.number_of_decimals:
                descriptors[var] = known._replace(
                    original=False,
                    number_of_decimals=var.number_of_decimals)

    # Second pass: reuse untouched templates, build copies for the rest.
    result = []
    for desc in descriptors.values():
        template = desc.template
        if desc.original:
            result.append(template)
            continue
        if template.is_discrete:
            merged = template.copy()
            # the template's own values come first in desc.values
            for val in desc.values[len(template.values):]:
                merged.add_value(val)
        else:
            assert template.is_continuous
            merged = template.copy(
                number_of_decimals=desc.number_of_decimals)
        result.append(merged)
    return result
269313

270314

271315
if __name__ == "__main__": # pragma: no cover

0 commit comments

Comments
 (0)