@@ -59,6 +59,10 @@ class Error(widget.OWWidget.Error):
5959 instances_mismatch = Msg ("Data sets do not contain the same instances." )
6060 too_many_inputs = Msg ("Venn diagram accepts at most five datasets." )
6161
class Warning(widget.OWWidget.Warning):
    # Shown by merge_data when variables with clashing names were replaced
    # by renamed copies; the placeholder receives the comma-separated
    # original names.
    renamed_vars = Msg("Some variables have been renamed "
                       "to avoid duplicates.\n {}")
65+
6266 selection : list
6367
6468 settingsHandler = settings .DomainContextHandler ()
@@ -73,6 +77,8 @@ class Error(widget.OWWidget.Error):
7377
7478 want_control_area = False
7579 graph_name = "scene"
80+ atr_types = ['attributes' , 'metas' , 'class_vars' ]
81+ atr_vals = {'metas' : 'metas' , 'attributes' : 'X' , 'class_vars' : 'Y' }
7682
7783 def __init__ (self ):
7884 super ().__init__ ()
@@ -347,18 +353,21 @@ def invalidateOutput(self):
347353
def merge_data(self, domain, values):
    """Assemble the output table from per-role variables and column blocks.

    Parameters
    ----------
    domain : dict
        Maps a role name ('attributes', 'metas', 'class_vars') to a mutable
        list of variables; forwarded as ``Domain(**domain)``. Variables whose
        names clash inside one role are replaced *in place* by renamed
        copies, so the original variable objects are left untouched.
    values : dict
        Maps a role name to a list of arrays that are stacked horizontally.
        A missing role, or a role with an empty list, yields ``None`` for
        that part of the table.

    Returns
    -------
    The result of ``Table.from_numpy`` for the merged domain and data.
    """
    renamed = []
    for variables in domain.values():
        names = [var.name for var in variables]
        unique_names = get_unique_names_duplicates(names)
        for idx, (old, new) in enumerate(zip(names, unique_names)):
            if old != new:
                variables[idx] = variables[idx].copy(name=new)
                renamed.append(old)
    if renamed:
        # dict.fromkeys drops repeated names while keeping first-seen order.
        self.Warning.renamed_vars(', '.join(dict.fromkeys(renamed)))

    def stack(role):
        # np.hstack raises on an empty sequence, so an empty list is
        # treated like an absent role.
        blocks = values.get(role)
        return np.hstack(blocks) if blocks else None

    X = stack('attributes')
    metas = stack('metas')
    class_vars = stack('class_vars')
    return Table.from_numpy(Domain(**domain), X, class_vars, metas)
364373
@@ -380,7 +389,7 @@ def extract_new_table(self, var_dict):
380389 values [atr_type ].append (getattr (self .data [var_data [1 ][0 ][1 ]].table [:, var_name ], atr_vals [atr_type ]).reshape (- 1 , 1 ))
381390 return self .merge_data (domain , values )
382391
383- def curry_merge (self , table_key , atr_type , ids = None ):
392+ def curry_merge (self , table_key , atr_type , ids = None , selection = False ):
384393 if self .rowwise :
385394 check_equality = self .arrays_equal_rows
386395 else :
@@ -389,23 +398,27 @@ def curry_merge(self, table_key, atr_type, ids=None):
389398 def inner (new_atrs , atr ):
390399 """
391400 Atrs - list of variables we wish to merge
392- new_atrs - dictionary where key is old name, val
393- is [is_different:bool, table_keys:list])
401+ new_atrs - dictionary where key is old var, val
402+ is [is_different:bool, table_keys:list]), is_different is set to True,
403+ if we are outputing duplicates, but the value is arbitrary
394404 """
395405 atr_vals = {'metas' : 'metas' , 'attributes' : 'X' , 'class_vars' : 'Y' }
396- if atr .name in new_atrs .keys ():
397- if not new_atrs [atr .name ][0 ]:
398- for var , key in new_atrs [atr .name ][1 ]:
406+ if atr in new_atrs .keys ():
407+ if not selection and self .output_duplicates :
408+ #if output_duplicates, we just check if compute value is the same
409+ new_atrs [atr ][0 ] = True
410+ elif not new_atrs [atr ][0 ]:
411+ for var , key in new_atrs [atr ][1 ]:
399412 if not check_equality (table_key ,
400413 key ,
401414 atr .name ,
402415 atr_vals [atr_type ],
403416 type (var ), ids ):
404- new_atrs [atr . name ][0 ] = True
417+ new_atrs [atr ][0 ] = True
405418 break
406- new_atrs [atr . name ][1 ].append ((atr , table_key ))
419+ new_atrs [atr ][1 ].append ((atr , table_key ))
407420 else :
408- new_atrs [atr . name ] = [False , [(atr , table_key )]]
421+ new_atrs [atr ] = [False , [(atr , table_key )]]
409422 return new_atrs
410423 return inner
411424
@@ -468,7 +481,7 @@ def extract_rowwise(self, var_dict, ids=None, selection=False):
468481 is [is_different:bool, table_keys:list])
469482 ids: dict with ids for each table
470483 """
471- all_ids = list (reduce (set .union , [set (val .keys ()) for val in ids .values ()], set ()))
484+ all_ids = sorted ( list (reduce (set .union , [set (val .keys ()) for val in ids .values ()], set () )))
472485
473486 permutations = dict ()
474487 for table_key , dict_ in ids .items ():
@@ -479,8 +492,8 @@ def extract_rowwise(self, var_dict, ids=None, selection=False):
479492 atr_vals = {'metas' : 'metas' , 'attributes' : 'X' , 'class_vars' : 'Y' }
480493 for atr_type , vars_dict in var_dict .items ():
481494 for var_name , var_data in vars_dict .items ():
482- duplicated = var_data [0 ]
483- if duplicated :
495+ different = var_data [0 ]
496+ if different :
484497 #columns are different, copy all, rename them
485498 for var , table_key in var_data [1 ]:
486499 temp = self .data [table_key ].table
@@ -517,41 +530,92 @@ def extract_rowwise(self, var_dict, ids=None, selection=False):
517530
def get_indices(self, table, selection):
    """Return a mapping of ids (row ids or feature values) to table indices.

    When ``self.selected_feature`` is set, the keys are the unique values of
    that feature's ``metas`` column; otherwise the keys are ``table.ids``.

    The values are row positions. With ``self.output_duplicates`` and a
    truthy ``selection`` each value is an array of *all* rows holding that
    feature value (ascending); otherwise it is a single index (the first
    occurrence for feature values, the row position for row ids).

    With a truthy ``selection`` only keys present in ``self.selected_items``
    are kept; a ``selected_items`` of ``None`` is treated as empty.
    """
    if self.selected_feature:
        column = table[:, self.selected_feature].metas
        if self.output_duplicates and selection:
            items, inverse, counts = np.unique(
                column, return_inverse=True, return_counts=True)
            # One stable sort groups the rows of every unique value at once
            # (rows stay ascending inside a group) instead of scanning the
            # whole column once per unique value.
            order = np.argsort(np.ravel(inverse), kind="stable")
            ids = np.split(order, np.cumsum(counts)[:-1])
        else:
            items, ids = np.unique(column, return_index=True)
    else:
        items = table.ids
        ids = range(len(table))

    if selection:
        # Build the lookup once rather than searching selected_items for
        # every candidate item.
        wanted = set(self.selected_items or ())
        return OrderedDict((item, idx) for item, idx in zip(items, ids)
                           if item in wanted)

    return OrderedDict(zip(items, ids))
551+
def get_indices_to_match_by(self, relevant_keys, selection=False):
    """Map every key in ``relevant_keys`` to ``get_indices`` of its table.

    The table is looked up as ``self.data[key].table`` and ``selection`` is
    passed through to ``get_indices`` unchanged.
    """
    return {
        key: self.get_indices(self.data[key].table, selection)
        for key in relevant_keys
    }
542558
def create_from_rows(self, relevant_keys, relevant_ids, selection=False):
    """Collect the variables of the relevant tables and build a merged table.

    For every role in ``self.atr_types`` the domain variables of each table
    in ``relevant_keys`` are folded into one dictionary with the reducer
    returned by ``self.curry_merge``. The per-role dictionaries are then
    handed to ``extract_rowwise_duplicates`` when ``self.output_duplicates``
    is set and ``selection`` is falsy, and to ``extract_rowwise`` otherwise.

    Parameters
    ----------
    relevant_keys : iterable
        Keys into ``self.data``; iterated once per role.
    relevant_ids : dict
        Id mappings, passed through to ``curry_merge`` and the extractor.
    selection : bool
        Passed through to ``curry_merge`` and ``extract_rowwise``.
    """
    var_dict = {}
    # Use the class-level role list so this method cannot drift from the
    # other methods that iterate the same roles.
    for atr_type in self.atr_types:
        container = {}
        for table_key in relevant_keys:
            merge_vars = self.curry_merge(
                table_key, atr_type, relevant_ids, selection)
            atrs = getattr(self.data[table_key].table.domain, atr_type)
            container = reduce(merge_vars, atrs, container)
        var_dict[atr_type] = container

    if self.output_duplicates and not selection:
        return self.extract_rowwise_duplicates(
            var_dict, relevant_ids, relevant_keys)
    return self.extract_rowwise(var_dict, relevant_ids, selection)
554572
def make_it_fit(self, a, b):
    """Return ``b`` arranged as a 2-D array for a target of shape ``a``.

    Parameters
    ----------
    a : sequence of int
        Target shape ``(rows, columns)``.
    b : array-like
        Data to arrange. Objects without a ``shape`` attribute (e.g. plain
        lists) are converted with ``np.asarray`` first.

    Returns
    -------
    ``b`` itself when its shape already equals ``a``; a column
    (``np.atleast_2d(b).T``) when the target has exactly one column;
    otherwise ``np.atleast_2d(b)``.
    """
    # TODO: rename function
    if not hasattr(b, "shape"):
        b = np.asarray(b)
    if tuple(a) == b.shape:
        return b
    if a[1] == 1:
        return np.atleast_2d(b).T
    return np.atleast_2d(b)
580+
def expand_tables(self, table, atrs, metas, cv):
    """Spread the columns of ``table`` over the full merged column layout.

    For each role in ``self.atr_types`` an array with ``len(table)`` rows and
    one column per variable of the merged layout is created and filled with
    NaN; the table's own columns are then written to the positions returned
    by ``get_perm``. Columns the table does not have stay NaN.

    Parameters
    ----------
    table
        Table whose ``domain`` and data arrays (named by ``self.atr_vals``)
        are read.
    atrs, metas, cv : list
        Full variable lists for attributes, metas and class variables.

    Returns
    -------
    tuple
        ``(attributes_array, metas_array, class_vars_array)``.
    """
    expanded = []
    for all_el, atr_type in zip([atrs, metas, cv], self.atr_types):
        cur_el = getattr(table.domain, atr_type)
        source = getattr(table, self.atr_vals[atr_type])
        # A float buffer cannot hold object values (e.g. strings), so follow
        # the source block when it is object-typed.
        dtype = object if getattr(source, "dtype", None) == object else float
        array = np.empty((len(table), len(all_el)), dtype=dtype)
        array.fill(np.nan)
        # A role the table has no variables for contributes nothing; its
        # columns simply stay NaN.
        if len(cur_el):
            perm = get_perm(cur_el, all_el)
            array[:, perm] = self.make_it_fit(array[:, perm].shape, source)
        expanded.append(array)
    # TODO: maybe this could be smarter
    x, m, y = expanded
    return x, m, y
595+
def extract_rowwise_duplicates(self, var_dict, ids, relevant_keys):
    """Build a table that keeps every matching row of every relevant table.

    For each id (in sorted order) and each table in ``relevant_keys`` that
    contains the id, the mapped rows are extracted, expanded to the full
    column layout with ``self.expand_tables`` and stacked vertically.

    Parameters
    ----------
    var_dict : dict
        Maps 'attributes', 'metas' and 'class_vars' to dictionaries keyed by
        variable; only the keys are used, sorted by variable name.
    ids : dict
        Maps a table key to a mapping ``id -> row index/indices``.
    relevant_keys : iterable
        Keys of the tables to take rows from.

    Returns
    -------
    The result of ``self.merge_data`` for the stacked data.
    """
    # For every id we need the row indices in each table: extract the whole
    # sub-table, insert any missing columns, and vstack everything at the end.
    all_ids = sorted(set().union(*(mapping.keys() for mapping in ids.values())))

    def by_name(var):
        return var.name

    all_atrs = sorted(var_dict['attributes'], key=by_name)
    all_metas = sorted(var_dict['metas'], key=by_name)
    all_cv = sorted(var_dict['class_vars'], key=by_name)

    # Iterated once per id, so make sure it can be traversed repeatedly.
    relevant_keys = list(relevant_keys)

    all_x, all_y, all_m = [], [], []
    for idx in all_ids:
        # iterate through the tables that contain this id
        for table_key in relevant_keys:
            table_ids = ids[table_key]
            # all_ids is a union, so an id may be absent from some tables.
            if idx not in table_ids:
                continue
            extracted = self.data[table_key].table[table_ids[idx]]
            x, m, y = self.expand_tables(extracted, all_atrs, all_metas, all_cv)
            all_x.append(x)
            all_y.append(y)
            all_m.append(m)

    if not all_x:
        # np.vstack rejects an empty sequence; emit zero-row blocks instead.
        all_x = [np.empty((0, len(all_atrs)))]
        all_m = [np.empty((0, len(all_metas)), dtype=object)]
        all_y = [np.empty((0, len(all_cv)))]

    domain = {'attributes': all_atrs, 'metas': all_metas, 'class_vars': all_cv}
    values = {'attributes': [np.vstack(all_x)],
              'metas': [np.vstack(all_m)],
              'class_vars': [np.vstack(all_y)]}
    return self.merge_data(domain, values)
618+
555619 def commit (self ):
556620
557621 if not self .vennwidget .vennareas () or not self .data .keys ():
@@ -569,7 +633,8 @@ def commit(self):
569633 selected = None
570634
571635 if self .rowwise :
572- selected_ids , annotated_ids = self .get_indices_to_match_by (selected_keys )
636+ selected_ids = self .get_indices_to_match_by (selected_keys , self .selection )
637+ annotated_ids = self .get_indices_to_match_by (self .data .keys ())
573638 annotated = self .create_from_rows (self .data .keys (), annotated_ids , True )
574639 if self .selected_items :
575640 selected = self .create_from_rows (selected_keys , selected_ids , False )
0 commit comments