99
1010import torch
1111
12+ from autointent import Dataset
1213from autointent .context import Context
1314from autointent .custom_types import NodeType
1415from autointent .modules .abc import Module
@@ -31,7 +32,7 @@ def __init__(
3132
3233 :param node_type: Node type
3334 :param search_space: Search space for the optimization
34- :param metric: Metric to optimize.
35+ :param metrics: Metrics to optimize.
3536 """
3637 self .node_type = node_type
3738 self .node_info = NODES_INFO [node_type ]
@@ -41,7 +42,7 @@ def __init__(
4142 if self .target_metric not in self .metrics :
4243 self .metrics .append (self .target_metric )
4344
44- self .modules_search_spaces = search_space # TODO search space validation
45+ self .modules_search_spaces = search_space
4546 self ._logger = logging .getLogger (__name__ ) # TODO solve duplicate logging messages problem
4647
4748 def fit (self , context : Context ) -> None :
@@ -143,3 +144,25 @@ def module_fit(self, module: Module, context: Context) -> None:
143144 self ._logger .error (msg )
144145 raise ValueError (msg )
145146 module .fit (* args ) # type: ignore[arg-type]
147+
148+ def validate_nodes_with_dataset (self , dataset : Dataset ) -> None :
149+ """
150+ Validate nodes with dataset.
151+
152+ :param dataset: Dataset to use
153+ """
154+ is_multilabel = dataset .multilabel
155+
156+ for search_space in deepcopy (self .modules_search_spaces ):
157+ module_name = search_space .pop ("module_name" )
158+ module = self .node_info .modules_available [module_name ]
159+ # todo add check for oos
160+
161+ if is_multilabel and not module .supports_multilabel :
162+ msg = f"Module '{ module_name } ' does not support multilabel datasets."
163+ self ._logger .error (msg )
164+ raise ValueError (msg )
165+ if not is_multilabel and not module .supports_multiclass :
166+ msg = f"Module '{ module_name } ' does not support multiclass datasets."
167+ self ._logger .error (msg )
168+ raise ValueError (msg )
0 commit comments