|
6 | 6 |
|
7 | 7 | """ |
8 | 8 |
|
9 | | -from collections import OrderedDict |
| 9 | +from collections import OrderedDict, namedtuple |
10 | 10 | from functools import reduce |
| 11 | +from itertools import chain, count |
| 12 | +from typing import List |
11 | 13 |
|
12 | 14 | import numpy as np |
13 | 15 | from AnyQt.QtWidgets import QFormLayout |
14 | 16 | from AnyQt.QtCore import Qt |
15 | 17 |
|
16 | 18 | import Orange.data |
| 19 | +from Orange.data.util import get_unique_names_duplicates |
17 | 20 | from Orange.util import flatten |
18 | 21 | from Orange.widgets import widget, gui, settings |
19 | 22 | from Orange.widgets.settings import Setting |
@@ -43,6 +46,10 @@ class Outputs: |
# Error messages of this widget, extending the base widget's Error group.
class Error(widget.OWWidget.Error):
    # NOTE(review): the code that raises this message is outside the visible
    # hunk; per its text it reports inputs whose types do not match.
    bow_concatenation = Msg("Inputs must be of the same type.")
45 | 48 |
|
# Warning messages of this widget, extending the base widget's Warning group.
class Warning(widget.OWWidget.Warning):
    # Shown by merge_domains whenever a variable is replaced by a renamed
    # copy; cleared at the start of apply.
    renamed_variables = Msg(
        "Variables with duplicated names have been renamed.")
| 52 | + |
46 | 53 | merge_type: int |
47 | 54 | append_source_column: bool |
48 | 55 | source_column_role: int |
@@ -172,18 +179,15 @@ def incompatible_types(self): |
172 | 179 | return False |
173 | 180 |
|
174 | 181 | def apply(self): |
| 182 | + self.Warning.renamed_variables.clear() |
175 | 183 | tables, domain, source_var = [], None, None |
176 | 184 | if self.primary_data is not None: |
177 | 185 | tables = [self.primary_data] + list(self.more_data.values()) |
178 | 186 | domain = self.primary_data.domain |
179 | 187 | elif self.more_data: |
180 | 188 | tables = self.more_data.values() |
181 | | - if self.merge_type == OWConcatenate.MergeUnion: |
182 | | - domain = reduce(domain_union, |
183 | | - (table.domain for table in tables)) |
184 | | - else: |
185 | | - domain = reduce(domain_intersection, |
186 | | - (table.domain for table in tables)) |
| 189 | + domains = [table.domain for table in tables] |
| 190 | + domain = self.merge_domains(domains) |
187 | 191 |
|
188 | 192 | if tables and self.append_source_column: |
189 | 193 | assert domain is not None |
@@ -236,36 +240,76 @@ def send_report(self): |
236 | 240 | self.id_roles[self.source_column_role].lower()) |
237 | 241 | self.report_items(items) |
238 | 242 |
|
239 | | - |
240 | | -def unique(seq): |
241 | | - seen_set = set() |
242 | | - for el in seq: |
243 | | - if el not in seen_set: |
244 | | - yield el |
245 | | - seen_set.add(el) |
246 | | - |
247 | | - |
248 | | -def domain_union(a, b): |
249 | | - union = Orange.data.Domain( |
250 | | - tuple(unique(a.attributes + b.attributes)), |
251 | | - tuple(unique(a.class_vars + b.class_vars)), |
252 | | - tuple(unique(a.metas + b.metas)) |
253 | | - ) |
254 | | - return union |
255 | | - |
256 | | - |
257 | | -def domain_intersection(a, b): |
258 | | - def tuple_intersection(t1, t2): |
259 | | - inters = set(t1) & set(t2) |
260 | | - return tuple(unique(el for el in t1 + t2 if el in inters)) |
261 | | - |
262 | | - intersection = Orange.data.Domain( |
263 | | - tuple_intersection(a.attributes, b.attributes), |
264 | | - tuple_intersection(a.class_vars, b.class_vars), |
265 | | - tuple_intersection(a.metas, b.metas), |
266 | | - ) |
267 | | - |
268 | | - return intersection |
def merge_domains(self, domains):
    """
    Build a single domain out of `domains`.

    Attributes, class variables and metas are merged separately: as a
    union when `self.merge_type` is `OWConcatenate.MergeUnion`, otherwise
    as an intersection. The names of all resulting variables are then
    passed to `get_unique_names_duplicates`; every variable whose name
    differs from the returned one is replaced with a renamed copy and the
    `renamed_variables` warning is shown.

    Args:
        domains: domains of the tables that are being concatenated

    Returns:
        Orange.data.Domain constructed from the three merged parts
    """
    if self.merge_type == OWConcatenate.MergeUnion:
        combine = set.union
    else:
        combine = set.intersection

    roles = ("attributes", "class_vars", "metas")
    parts = [self._get_part(domains, combine, role) for role in roles]

    # A single iterator is shared by all three parts, so each part continues
    # with the names where the previous part stopped.
    unique_names = iter(get_unique_names_duplicates(
        [var.name for var in chain(*parts)]))
    for variables in parts:
        # `variables` comes first in zip: when it runs out, zip stops
        # without taking another name from the shared iterator.
        for index, (var, new_name) in enumerate(zip(variables, unique_names)):
            if var.name != new_name:
                variables[index] = var.renamed(new_name)
                self.Warning.renamed_variables()
    return Orange.data.Domain(*parts)
| 260 | + |
@classmethod
def _get_part(cls, domains, oper, part):
    """
    Merge one part of the given domains into a list of variables.

    Args:
        domains: domains to merge
        oper: binary set operation folded over the per-domain variable sets
            (the caller passes `set.union` or `set.intersection`)
        part: name of the domain attribute to read, e.g. "attributes"

    Returns:
        list of variables produced by `_unique_vars`
    """
    per_domain = [getattr(domain, part) for domain in domains]
    # Sets only decide which variables are kept ...
    allowed = reduce(oper, (set(variables) for variables in per_domain))
    # ... while the order is taken from the original sequences, domain by
    # domain.
    ordered = [var
               for variables in per_domain
               for var in variables
               if var in allowed]
    return cls._unique_vars(ordered)
| 269 | + |
@staticmethod
def _unique_vars(seq: List[Orange.data.Variable]):
    """
    Collapse variables that compare equal into a single variable each.

    The result keeps the order of first occurrence. When later occurrences
    of a discrete variable carry values that the first one lacks, the
    result is a copy of the first occurrence with those values added at
    the end. When a later occurrence of a continuous variable has a larger
    `number_of_decimals`, the result is a copy with that larger number.
    Otherwise the first occurrence itself is returned.

    Args:
        seq: variables, possibly containing equal ones

    Returns:
        list with one variable per group of equal variables
    """
    AttrDesc = namedtuple(
        "AttrDesc",
        ("template", "original", "values", "number_of_decimals"))

    # Keyed by variable, so equal variables share one description; dict
    # order preserves the order of first occurrence.
    attrs = {}
    for el in seq:
        desc = attrs.get(el)
        if desc is None:
            attrs[el] = AttrDesc(el, True,
                                 el.is_discrete and el.values,
                                 el.is_continuous and el.number_of_decimals)
            continue
        if desc.template.is_discrete:
            sattr_values = set(desc.values)
            # don't use sets: keep the order
            missing_values = [val for val in el.values
                              if val not in sattr_values]
            if missing_values:
                # list(...) makes the concatenation independent of the
                # sequence type of `values`; tuple + list would raise
                # TypeError.
                attrs[el] = desc._replace(
                    original=False,
                    values=list(desc.values) + missing_values)
        elif desc.template.is_continuous:
            if el.number_of_decimals > desc.number_of_decimals:
                attrs[el] = desc._replace(
                    original=False,
                    number_of_decimals=el.number_of_decimals)

    new_attrs = []
    for desc in attrs.values():
        attr = desc.template
        if desc.original:
            new_attr = attr
        elif attr.is_discrete:
            new_attr = attr.copy()
            # desc.values starts with the template's own values, so only
            # the tail has to be added to the copy.
            for val in desc.values[len(attr.values):]:
                new_attr.add_value(val)
        else:
            assert attr.is_continuous
            new_attr = attr.copy(number_of_decimals=desc.number_of_decimals)
        new_attrs.append(new_attr)
    return new_attrs
269 | 313 |
|
270 | 314 |
|
271 | 315 | if __name__ == "__main__": # pragma: no cover |
|
0 commit comments