@@ -166,3 +166,34 @@ class BaseRegressionMethod(BaseMethod):
166166class BaseClusteringMethod (BaseMethod ):
167167
168168 _DEFAULT_METRIC = "ari"
169+
170+ def score (self , x , y , * , score_func : Optional [Union [str , Mapping [Any , float ]]] = None , return_pred : bool = False ,
171+ valid_idx = None , test_idx = None ) -> Union [float , Tuple [float , Any ]]:
172+ y_pred = self .predict (x )
173+ func = resolve_score_func (score_func or self ._DEFAULT_METRIC )
174+ if valid_idx is None :
175+ score = func (y , y_pred )
176+ return (score , y_pred ) if return_pred else score
177+ else :
178+ valid_score = func ([y [i ] for i in valid_idx ], [y_pred [i ] for i in valid_idx ])
179+ test_score = func ([y [i ] for i in test_idx ], [y_pred [i ] for i in test_idx ])
180+ return ({
181+ "valid_score" : valid_score ,
182+ "test_score" : test_score
183+ }, y_pred ) if return_pred else {
184+ "valid_score" : valid_score ,
185+ "test_score" : test_score
186+ }
187+
188+ def fit_score (self , x , y , * , score_func : Optional [Union [str , Mapping [Any ,
189+ float ]]] = None , return_pred : bool = False ,
190+ valid_idx = None , test_idx = None , ** fit_kwargs ) -> Union [float , Tuple [float , Any ]]:
191+ """Shortcut for fitting data using the input feature and return eval.
192+
193+ Note
194+ ----
195+ Only work for models where the fitting does not require labeled data, i.e. unsupervised methods.
196+
197+ """
198+ self .fit (x , ** fit_kwargs )
199+ return self .score (x , y , score_func = score_func , return_pred = return_pred , valid_idx = valid_idx , test_idx = test_idx )
0 commit comments