@@ -181,10 +181,11 @@ def query(self, X_pool, return_metrics: bool = False, *query_args, **query_kwarg
181181 query_metrics = None
182182 query_result = self .query_strategy (
183183 self , X_pool , * query_args , ** query_kwargs )
184- warnings .warn (
185- "The selected query strategy doesn't support return_metrics" )
186184
187185 if return_metrics :
186+ if query_metrics is None :
187+ warnings .warn (
188+ "The selected query strategy doesn't support return_metrics" )
188189 return query_result , retrieve_rows (X_pool , query_result ), query_metrics
189190 else :
190191 return query_result , retrieve_rows (X_pool , query_result )
@@ -246,6 +247,10 @@ def _fit_on_new(self, X: modALinput, y: modALinput, bootstrap: bool = False, **f
246247 for learner in self .learner_list :
247248 learner ._fit_on_new (X , y , bootstrap = bootstrap , ** fit_kwargs )
248249
250+ @abc .abstractmethod
251+ def teach (self , X : modALinput , y : modALinput , bootstrap : bool = False , ** fit_kwargs ) -> Any :
252+ pass
253+
249254 @abc .abstractmethod
250255 def predict (self , X : modALinput ) -> Any :
251256 pass
@@ -288,10 +293,11 @@ def query(self, X_pool, return_metrics: bool = False, *query_args, **query_kwarg
288293 query_metrics = None
289294 query_result = self .query_strategy (
290295 self , X_pool , * query_args , ** query_kwargs )
291- warnings .warn (
292- "The selected query strategy doesn't support return_metrics" )
293296
294297 if return_metrics :
298+ if query_metrics is None :
299+ warnings .warn (
300+ "The selected query strategy doesn't support return_metrics" )
295301 return query_result , retrieve_rows (X_pool , query_result ), query_metrics
296302 else :
297303 return query_result , retrieve_rows (X_pool , query_result )
0 commit comments