1111from Orange .clustering .kmeans import KMeansModel , SILHOUETTE_MAX_SAMPLES
1212from Orange .data import Table , Domain , DiscreteVariable , ContinuousVariable
1313from Orange .data .util import get_unique_names
14+ from Orange .preprocess .impute import ReplaceUnknowns
1415from Orange .widgets import widget , gui
1516from Orange .widgets .settings import Setting
1617from Orange .widgets .utils .annotated_data import \
@@ -33,7 +34,8 @@ def __init__(self, parent=None):
def rowCount(self, index=QModelIndex()):
    """Return the number of rows below ``index``.

    A valid ``index`` yields 0; an invalid one (the default) yields
    ``len(self.scores)``.
    """
    if index.isValid():
        return 0
    return len(self.scores)
3536
@staticmethod
def columnCount(_index=QModelIndex()):
    """Return the column count, which is always 1.

    ``_index`` is accepted but never read, which is why the method needs no
    instance state and is declared static.
    """
    return 1
3840
3941 def flags (self , index ):
@@ -64,10 +66,12 @@ def data(self, index, role=Qt.DisplayRole):
6466 return score
6567 elif role == gui .BarRatioRole and valid :
6668 return score
69+ return None
6770
def headerData(self, row, _orientation, role=Qt.DisplayRole):
    """Return the header label for ``row``.

    For ``Qt.DisplayRole`` the label is ``row + self.start_k`` rendered as a
    string. Every other role yields ``None``. ``_orientation`` is accepted
    but not consulted.
    """
    if role == Qt.DisplayRole:
        return str(row + self.start_k)
    # Explicit fall-through so every path returns a value.
    return None
7175
7276
7377class Task :
@@ -443,8 +447,9 @@ def update_results(self):
443447
def selected_row(self):
    """Return the row of the first selected index in ``self.table_view``.

    Returns ``None`` when the view reports no selected indexes.
    """
    indices = self.table_view.selectedIndexes()
    if not indices:
        return None
    return indices[0].row()
448453
449454 def select_row (self ):
450455 self .send_data ()
@@ -468,21 +473,49 @@ def send_data(self):
468473 values = ["C%d" % (x + 1 ) for x in range (km .k )]
469474 )
470475 clust_ids = km (self .data )
476+ clust_col = clust_ids .X .ravel ()
471477 silhouette_var = ContinuousVariable (
472478 get_unique_names (domain , "Silhouette" ))
473479 if km .silhouette_samples is not None :
474480 self .Warning .no_silhouettes .clear ()
475481 scores = np .arctan (km .silhouette_samples ) / np .pi + 0.5
482+ clust_scores = []
483+ for i in range (km .k ):
484+ in_clust = clust_col == i
485+ if in_clust .any ():
486+ clust_scores .append (np .mean (scores [in_clust ]))
487+ else :
488+ clust_scores .append (0. )
489+ clust_scores = np .atleast_2d (clust_scores ).T
476490 else :
477491 self .Warning .no_silhouettes ()
478492 scores = np .nan
493+ clust_scores = np .full ((km .k , 1 ), np .nan )
479494
480495 new_domain = add_columns (domain , metas = [cluster_var , silhouette_var ])
481496 new_table = self .data .transform (new_domain )
482- new_table .get_column_view (cluster_var )[0 ][:] = clust_ids . X . ravel ()
497+ new_table .get_column_view (cluster_var )[0 ][:] = clust_col
483498 new_table .get_column_view (silhouette_var )[0 ][:] = scores
484499
485- centroids = Table (Domain (km .pre_domain .attributes ), km .centroids )
500+ centroid_attributes = [
501+ attr .compute_value .variable
502+ if isinstance (attr .compute_value , ReplaceUnknowns )
503+ and attr .compute_value .variable in domain .attributes
504+ else attr
505+ for attr in km .pre_domain .attributes ]
506+ centroid_domain = add_columns (
507+ Domain (centroid_attributes , [], domain .metas ),
508+ metas = [cluster_var , silhouette_var ])
509+ centroids = Table (
510+ centroid_domain , km .centroids , None ,
511+ np .hstack ((np .full ((km .k , len (domain .metas )), np .nan ),
512+ np .arange (km .k ).reshape (km .k , 1 ),
513+ clust_scores ))
514+ )
515+ if self .data .name == Table .name :
516+ centroids .name = "centroids"
517+ else :
518+ centroids .name = f"{ self .data .name } centroids"
486519
487520 self .Outputs .annotated_data .send (new_table )
488521 self .Outputs .centroids .send (centroids )
0 commit comments