Skip to content

Commit acde018

Browse files
committed
Concatenate: Refactor merge_domain, prevent duplicated names
1 parent ea757b6 commit acde018

File tree

2 files changed

+237
-68
lines changed

2 files changed

+237
-68
lines changed

Orange/widgets/data/owconcatenate.py

Lines changed: 80 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,17 @@
66
77
"""
88

9-
from collections import OrderedDict
9+
from collections import OrderedDict, namedtuple
1010
from functools import reduce
11+
from itertools import chain, count
1112
from typing import List
1213

1314
import numpy as np
1415
from AnyQt.QtWidgets import QFormLayout
1516
from AnyQt.QtCore import Qt
1617

1718
import Orange.data
19+
from Orange.data.util import get_unique_names_duplicates
1820
from Orange.util import flatten
1921
from Orange.widgets import widget, gui, settings
2022
from Orange.widgets.settings import Setting
@@ -44,6 +46,10 @@ class Outputs:
4446
class Error(widget.OWWidget.Error):
4547
bow_concatenation = Msg("Inputs must be of the same type.")
4648

49+
class Warning(widget.OWWidget.Warning):
50+
renamed_variables = Msg(
51+
"Variables with duplicated names have been renamed.")
52+
4753
merge_type: int
4854
append_source_column: bool
4955
source_column_role: int
@@ -173,18 +179,17 @@ def incompatible_types(self):
173179
return False
174180

175181
def apply(self):
182+
self.Warning.renamed_variables.clear()
176183
tables, domain, source_var = [], None, None
177184
if self.primary_data is not None:
178185
tables = [self.primary_data] + list(self.more_data.values())
179186
domain = self.primary_data.domain
180187
elif self.more_data:
181188
tables = self.more_data.values()
182-
if self.merge_type == OWConcatenate.MergeUnion:
183-
domain = reduce(domain_union,
184-
(table.domain for table in tables))
185-
else:
186-
domain = reduce(domain_intersection,
187-
(table.domain for table in tables))
189+
domains = [table.domain for table in tables]
190+
oper = set.union if self.merge_type == OWConcatenate.MergeUnion \
191+
else set.intersection
192+
domain = self.merge_domains(domains, oper)
188193

189194
if tables and self.append_source_column:
190195
assert domain is not None
@@ -237,56 +242,74 @@ def send_report(self):
237242
self.id_roles[self.source_column_role].lower())
238243
self.report_items(items)
239244

240-
241-
def unique(seq: List[Orange.data.Variable]):
242-
attrs = {}
243-
for el in seq:
244-
if el not in attrs:
245-
attrs[el] = el, True
246-
continue
247-
if el.is_string:
248-
continue
249-
attr, orig_attr = attrs.get(el)
250-
if el.is_discrete:
251-
sel_values = set(el.values)
252-
sattr_values = set(attr.values)
253-
if orig_attr and sel_values != sattr_values:
254-
del attrs[attr]
255-
attr = attr.copy()
256-
attrs[attr] = attr, False
257-
for val in sel_values:
258-
if val not in attr.values: # don't use sets: keep the order
259-
attr.add_value(val)
260-
else: # ContinuousVariable
261-
num_dec = max(attr.number_of_decimals, el.number_of_decimals)
262-
if orig_attr and num_dec < el.number_of_decimals:
263-
del attrs[attr]
264-
attr = attr.copy(number_of_decimals=num_dec)
265-
attrs[attr] = attr, False
266-
return list(attrs)
267-
268-
269-
def domain_union(a, b):
270-
union = Orange.data.Domain(
271-
tuple(unique(a.attributes + b.attributes)),
272-
tuple(unique(a.class_vars + b.class_vars)),
273-
tuple(unique(a.metas + b.metas))
274-
)
275-
return union
276-
277-
278-
def domain_intersection(a, b):
279-
def tuple_intersection(t1, t2):
280-
inters = set(t1) & set(t2)
281-
return tuple(unique(el for el in t1 + t2 if el in inters))
282-
283-
intersection = Orange.data.Domain(
284-
tuple_intersection(a.attributes, b.attributes),
285-
tuple_intersection(a.class_vars, b.class_vars),
286-
tuple_intersection(a.metas, b.metas),
287-
)
288-
289-
return intersection
245+
def merge_domains(self, domains, oper):
246+
def fix_names(part):
247+
for i, attr, name in zip(count(), part, name_iter):
248+
if attr.name != name:
249+
part[i] = attr.renamed(name)
250+
self.Warning.renamed_variables()
251+
252+
parts = [self._get_part(domains, oper, part)
253+
for part in ("attributes", "class_vars", "metas")]
254+
all_names = [var.name for var in chain(*parts)]
255+
name_iter = iter(get_unique_names_duplicates(all_names))
256+
for part in parts:
257+
fix_names(part)
258+
domain = Orange.data.Domain(*parts)
259+
return domain
260+
261+
@classmethod
262+
def _get_part(cls, domains, oper, part):
263+
# keep the order of variables: first compute union or intersections as
264+
# sets, then iterate through chained parts
265+
vars_by_domain = [getattr(domain, part) for domain in domains]
266+
valid = reduce(oper, map(set, vars_by_domain))
267+
valid_vars = [var for var in chain(*vars_by_domain) if var in valid]
268+
return cls._unique_vars(valid_vars)
269+
270+
@staticmethod
271+
def _unique_vars(seq: List[Orange.data.Variable]):
272+
AttrDesc = namedtuple(
273+
"AttrDesc",
274+
("template", "original", "values", "number_of_decimals"))
275+
276+
attrs = {}
277+
for el in seq:
278+
desc = attrs.get(el)
279+
if desc is None:
280+
attrs[el] = AttrDesc(el, True,
281+
el.is_discrete and el.values,
282+
el.is_continuous and el.number_of_decimals)
283+
continue
284+
if desc.template.is_discrete:
285+
sattr_values = set(desc.values)
286+
# don't use sets: keep the order
287+
missing_values = [val for val in el.values
288+
if val not in sattr_values]
289+
if missing_values:
290+
attrs[el] = attrs[el]._replace(
291+
original=False,
292+
values=desc.values + missing_values)
293+
elif desc.template.is_continuous:
294+
if el.number_of_decimals > desc.number_of_decimals:
295+
attrs[el] = attrs[el]._replace(
296+
original=False,
297+
number_of_decimals=el.number_of_decimals)
298+
299+
new_attrs = []
300+
for desc in attrs.values():
301+
attr = desc.template
302+
if desc.original:
303+
new_attr = attr
304+
elif desc.template.is_discrete:
305+
new_attr = attr.copy()
306+
for val in desc.values[len(attr.values):]:
307+
new_attr.add_value(val)
308+
else:
309+
assert desc.template.is_continuous
310+
new_attr = attr.copy(number_of_decimals=desc.number_of_decimals)
311+
new_attrs.append(new_attr)
312+
return new_attrs
290313

291314

292315
if __name__ == "__main__": # pragma: no cover

0 commit comments

Comments
 (0)