@@ -241,7 +241,7 @@ def get_module_dump_dir(self, context: Context, module_name: str, j_combination:
241241 dump_dir_ .mkdir (parents = True , exist_ok = True )
242242 return str (dump_dir_ )
243243
244- def validate_nodes_with_dataset (self , dataset : Dataset , mode : SearchSpaceValidationMode ) -> None :
244+ def validate_nodes_with_dataset (self , dataset : Dataset , mode : SearchSpaceValidationMode ) -> None : # noqa: C901
245245 """Validates nodes against the dataset.
246246
247247 Args:
@@ -254,12 +254,24 @@ def validate_nodes_with_dataset(self, dataset: Dataset, mode: SearchSpaceValidat
254254 is_multilabel = dataset .multilabel
255255
256256 filtered_search_space = []
257+ if is_multilabel and self .target_metric not in self .node_info .multilabel_available_metrics :
258+ handle_message_on_mode (
259+ mode , f"Target metric '{ self .target_metric } ' is not available for multilabel datasets." , True
260+ )
261+ elif not is_multilabel and self .target_metric not in self .node_info .multiclass_available_metrics :
262+ handle_message_on_mode (
263+ mode , f"Target metric '{ self .target_metric } ' is not available for multiclass datasets." , True
264+ )
265+
266+ for metric in self .metrics :
267+ if is_multilabel and metric not in self .node_info .multilabel_available_metrics :
268+ handle_message_on_mode (mode , f"Metric '{ metric } ' is not available for multilabel datasets." , True )
269+ elif not is_multilabel and metric not in self .node_info .multiclass_available_metrics :
270+ handle_message_on_mode (mode , f"Metric '{ metric } ' is not available for multiclass datasets." , True )
257271
258272 for search_space in deepcopy (self .modules_search_spaces ):
259273 module_name = search_space ["module_name" ]
260274 module = self .node_info .modules_available [module_name ]
261- # todo add check for oos
262-
263275 messages = []
264276
265277 if module_name == "description" and not dataset .has_descriptions :
@@ -273,11 +285,7 @@ def validate_nodes_with_dataset(self, dataset: Dataset, mode: SearchSpaceValidat
273285
274286 if len (messages ) > 0 :
275287 msg = "\n " .join (messages )
276- if mode == "raise" :
277- self ._logger .error (msg )
278- raise ValueError (msg )
279- if mode == "warning" :
280- self ._logger .warning (msg )
288+ handle_message_on_mode (mode , msg )
281289 else :
282290 filtered_search_space .append (search_space )
283291
@@ -393,3 +401,26 @@ def load_or_create_study(
393401 finished_trials ,
394402 remaining_trials ,
395403 )
404+
405+
406+ def handle_message_on_mode (
407+ mode : SearchSpaceValidationMode ,
408+ message : str ,
409+ strict : bool = False ,
410+ ) -> None :
411+ """Handle messages based on the validation mode.
412+
413+ Args:
414+ mode: The validation mode ("raise" or "warning").
415+ message: The message to handle.
416+ strict: If True always raises an error, even if mode is "warning".
417+
418+ Raises:
419+ ValueError: If mode is "raise".
420+ """
421+ if mode == "raise" :
422+ raise ValueError (message )
423+ if mode == "warning" :
424+ logger .warning (message )
425+ if strict :
426+ raise ValueError (message )
0 commit comments