|
6 | 6 |
|
7 | 7 | """ |
8 | 8 |
|
9 | | -from collections import OrderedDict |
| 9 | +from collections import OrderedDict, namedtuple |
10 | 10 | from functools import reduce |
| 11 | +from itertools import chain, count |
11 | 12 | from typing import List |
12 | 13 |
|
13 | 14 | import numpy as np |
14 | 15 | from AnyQt.QtWidgets import QFormLayout |
15 | 16 | from AnyQt.QtCore import Qt |
16 | 17 |
|
17 | 18 | import Orange.data |
| 19 | +from Orange.data.util import get_unique_names_duplicates |
18 | 20 | from Orange.util import flatten |
19 | 21 | from Orange.widgets import widget, gui, settings |
20 | 22 | from Orange.widgets.settings import Setting |
@@ -44,6 +46,10 @@ class Outputs: |
class Error(widget.OWWidget.Error):
    # Message for inputs whose types do not match each other.
    bow_concatenation = Msg(
        "Inputs must be of the same type."
    )
46 | 48 |
|
class Warning(widget.OWWidget.Warning):
    # Message for output variables that were renamed because their names
    # were duplicated.
    renamed_variables = Msg("Variables with duplicated names have been renamed.")
| 52 | + |
47 | 53 | merge_type: int |
48 | 54 | append_source_column: bool |
49 | 55 | source_column_role: int |
@@ -173,18 +179,17 @@ def incompatible_types(self): |
173 | 179 | return False |
174 | 180 |
|
175 | 181 | def apply(self): |
| 182 | + self.Warning.renamed_variables.clear() |
176 | 183 | tables, domain, source_var = [], None, None |
177 | 184 | if self.primary_data is not None: |
178 | 185 | tables = [self.primary_data] + list(self.more_data.values()) |
179 | 186 | domain = self.primary_data.domain |
180 | 187 | elif self.more_data: |
181 | 188 | tables = self.more_data.values() |
182 | | - if self.merge_type == OWConcatenate.MergeUnion: |
183 | | - domain = reduce(domain_union, |
184 | | - (table.domain for table in tables)) |
185 | | - else: |
186 | | - domain = reduce(domain_intersection, |
187 | | - (table.domain for table in tables)) |
| 189 | + domains = [table.domain for table in tables] |
| 190 | + oper = set.union if self.merge_type == OWConcatenate.MergeUnion \ |
| 191 | + else set.intersection |
| 192 | + domain = self.merge_domains(domains, oper) |
188 | 193 |
|
189 | 194 | if tables and self.append_source_column: |
190 | 195 | assert domain is not None |
@@ -237,56 +242,74 @@ def send_report(self): |
237 | 242 | self.id_roles[self.source_column_role].lower()) |
238 | 243 | self.report_items(items) |
239 | 244 |
|
240 | | - |
241 | | -def unique(seq: List[Orange.data.Variable]): |
242 | | - attrs = {} |
243 | | - for el in seq: |
244 | | - if el not in attrs: |
245 | | - attrs[el] = el, True |
246 | | - continue |
247 | | - if el.is_string: |
248 | | - continue |
249 | | - attr, orig_attr = attrs.get(el) |
250 | | - if el.is_discrete: |
251 | | - sel_values = set(el.values) |
252 | | - sattr_values = set(attr.values) |
253 | | - if orig_attr and sel_values != sattr_values: |
254 | | - del attrs[attr] |
255 | | - attr = attr.copy() |
256 | | - attrs[attr] = attr, False |
257 | | - for val in sel_values: |
258 | | - if val not in attr.values: # don't use sets: keep the order |
259 | | - attr.add_value(val) |
260 | | - else: # ContinuousVariable |
261 | | - num_dec = max(attr.number_of_decimals, el.number_of_decimals) |
262 | | - if orig_attr and num_dec < el.number_of_decimals: |
263 | | - del attrs[attr] |
264 | | - attr = attr.copy(number_of_decimals=num_dec) |
265 | | - attrs[attr] = attr, False |
266 | | - return list(attrs) |
267 | | - |
268 | | - |
269 | | -def domain_union(a, b): |
270 | | - union = Orange.data.Domain( |
271 | | - tuple(unique(a.attributes + b.attributes)), |
272 | | - tuple(unique(a.class_vars + b.class_vars)), |
273 | | - tuple(unique(a.metas + b.metas)) |
274 | | - ) |
275 | | - return union |
276 | | - |
277 | | - |
278 | | -def domain_intersection(a, b): |
279 | | - def tuple_intersection(t1, t2): |
280 | | - inters = set(t1) & set(t2) |
281 | | - return tuple(unique(el for el in t1 + t2 if el in inters)) |
282 | | - |
283 | | - intersection = Orange.data.Domain( |
284 | | - tuple_intersection(a.attributes, b.attributes), |
285 | | - tuple_intersection(a.class_vars, b.class_vars), |
286 | | - tuple_intersection(a.metas, b.metas), |
287 | | - ) |
288 | | - |
289 | | - return intersection |
def merge_domains(self, domains, oper):
    """
    Combine `domains` into one domain whose variable names are unique.

    Attributes, class variables and metas are merged separately by
    `_get_part`, which uses `oper` (a set operation) to decide which
    variables are kept. The names of all resulting variables are then
    passed through `get_unique_names_duplicates`; every variable whose
    name differs from the proposed one is replaced by a renamed copy and
    the `renamed_variables` warning is shown.

    Returns the new `Orange.data.Domain`.
    """
    roles = ("attributes", "class_vars", "metas")
    merged = [self._get_part(domains, oper, role) for role in roles]

    # Names must be unique across all three parts, so they are computed
    # in one go and handed out through a single shared iterator.
    proposed = iter(get_unique_names_duplicates(
        [var.name for var in chain.from_iterable(merged)]))

    for variables in merged:
        # `variables` precedes `proposed` in zip: when a part runs out,
        # zip stops before pulling a name that belongs to the next part.
        for index, (var, new_name) in enumerate(zip(variables, proposed)):
            if var.name == new_name:
                continue
            variables[index] = var.renamed(new_name)
            self.Warning.renamed_variables()

    return Orange.data.Domain(*merged)
| 260 | + |
@classmethod
def _get_part(cls, domains, oper, part):
    """
    Merge one part of the given domains into a list of variables.

    `part` is the name of the domain attribute to read (for instance
    "attributes", "class_vars" or "metas"). `oper` is a binary set
    operation that is folded over the per-domain variable sets to decide
    which variables are kept.

    The set operation only decides membership. The order of the result
    comes from walking the per-domain variables one domain after another,
    and duplicates are then resolved by `_unique_vars`.
    """
    per_domain = [getattr(domain, part) for domain in domains]
    allowed = reduce(oper, (set(variables) for variables in per_domain))

    ordered = []
    for variables in per_domain:
        ordered.extend(var for var in variables if var in allowed)
    return cls._unique_vars(ordered)
| 269 | + |
@staticmethod
def _unique_vars(seq: List[Orange.data.Variable]):
    """
    Return the distinct variables of `seq`, merging equal ones.

    Variables are reported in the order of their first occurrence. When a
    variable that compares equal to an earlier one is seen again:

    - for discrete variables, values not yet known are appended (in the
      order in which they appear) to the values of the first occurrence;
    - for continuous variables, the larger `number_of_decimals` is kept;
    - for any other variable, the first occurrence is kept unchanged.

    The input variables are never modified. A variable whose description
    had to change is returned as a copy; all others are returned as the
    original objects.
    """
    AttrDesc = namedtuple(
        "AttrDesc",
        ("template", "original", "values", "number_of_decimals"))

    # First pass: collect the merged description of every distinct variable.
    # Looking up by `el` finds the entry of an equal, earlier variable; the
    # dictionary keeps that earlier variable as the key.
    attrs = {}
    for el in seq:
        desc = attrs.get(el)
        if desc is None:
            attrs[el] = AttrDesc(
                el, True,
                el.is_discrete and el.values,
                el.is_continuous and el.number_of_decimals)
            continue
        if desc.template.is_discrete:
            known_values = set(desc.values)
            # don't use sets: keep the order
            missing_values = [val for val in el.values
                              if val not in known_values]
            if missing_values:
                # `values` may be a list or a tuple depending on the
                # variable implementation; convert it to a list so that
                # the concatenation works for both.
                attrs[el] = desc._replace(
                    original=False,
                    values=list(desc.values) + missing_values)
        elif desc.template.is_continuous:
            if el.number_of_decimals > desc.number_of_decimals:
                attrs[el] = desc._replace(
                    original=False,
                    number_of_decimals=el.number_of_decimals)

    # Second pass: build the output, copying only the variables that changed.
    new_attrs = []
    for desc in attrs.values():
        attr = desc.template
        if desc.original:
            new_attr = attr
        elif attr.is_discrete:
            new_attr = attr.copy()
            # The merged values start with the template's own values, so
            # only the tail needs to be added to the copy.
            for val in desc.values[len(attr.values):]:
                new_attr.add_value(val)
        else:
            # Only discrete and continuous descriptions are ever marked as
            # changed in the first pass.
            assert attr.is_continuous
            new_attr = attr.copy(number_of_decimals=desc.number_of_decimals)
        new_attrs.append(new_attr)
    return new_attrs
290 | 313 |
|
291 | 314 |
|
292 | 315 | if __name__ == "__main__": # pragma: no cover |
|
0 commit comments