diff --git a/autointent/_dataset/_dataset.py b/autointent/_dataset/_dataset.py index d15ca8aee..760601aa1 100644 --- a/autointent/_dataset/_dataset.py +++ b/autointent/_dataset/_dataset.py @@ -146,6 +146,7 @@ def push_to_hub(self, repo_id: str, private: bool = False) -> None: Push dataset splits to a Hugging Face repository. :param repo_id: ID of the Hugging Face repository. + :param private: Whether the repository is private """ for split_name, split in self.items(): split.push_to_hub(repo_id, split=split_name, private=private) diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py index 2201575b0..3415b948e 100644 --- a/autointent/_pipeline/_pipeline.py +++ b/autointent/_pipeline/_pipeline.py @@ -138,7 +138,7 @@ def fit(self, dataset: Dataset) -> Context: context.configure_logging(self.logging_config) context.configure_vector_index(self.vector_index_config, self.embedder_config) context.configure_cross_encoder(self.cross_encoder_config) - + self.validate_modules(dataset) self._fit(context) if context.is_ram_to_clear(): @@ -160,6 +160,16 @@ def fit(self, dataset: Dataset) -> Context: return context + def validate_modules(self, dataset: Dataset) -> None: + """ + Validate modules with dataset. + + :param dataset: dataset to validate with + """ + for node in self.nodes.values(): + if isinstance(node, NodeOptimizer): + node.validate_nodes_with_dataset(dataset) + @classmethod def from_dict_config(cls, nodes_configs: list[dict[str, Any]]) -> "Pipeline": """ diff --git a/autointent/generation/utterances/__init__.py b/autointent/generation/utterances/__init__.py index 7534eb1a0..006a163df 100644 --- a/autointent/generation/utterances/__init__.py +++ b/autointent/generation/utterances/__init__.py @@ -1,28 +1,28 @@ from .basic import SynthesizerChatTemplate, UtteranceGenerator from .evolution import ( - AbstractEvolution, - ConcreteEvolution, - EvolutionChatTemplate, - FormalEvolution, - FunnyEvolution, - GoofyEvolution, - InformalEvolution, - ReasoningEvolution, - UtteranceEvolver, + AbstractEvolution, + ConcreteEvolution, + EvolutionChatTemplate, + FormalEvolution, + FunnyEvolution, + GoofyEvolution, + InformalEvolution, + ReasoningEvolution, + UtteranceEvolver, ) from .generator import Generator __all__ = [ - "AbstractEvolution", - "ConcreteEvolution", - "EvolutionChatTemplate", - "FormalEvolution", - "FunnyEvolution", - "Generator", - "GoofyEvolution", - "InformalEvolution", - "ReasoningEvolution", - "SynthesizerChatTemplate", - "UtteranceEvolver", - "UtteranceGenerator", + "AbstractEvolution", + "ConcreteEvolution", + "EvolutionChatTemplate", + "FormalEvolution", + "FunnyEvolution", + "Generator", + "GoofyEvolution", + "InformalEvolution", + "ReasoningEvolution", + "SynthesizerChatTemplate", + "UtteranceEvolver", + "UtteranceGenerator", ] diff --git a/autointent/generation/utterances/evolution/__init__.py b/autointent/generation/utterances/evolution/__init__.py index 7e352bd86..596d83a3f 100644 --- a/autointent/generation/utterances/evolution/__init__.py +++ b/autointent/generation/utterances/evolution/__init__.py @@ -1,23 +1,23 @@ from .chat_templates import ( - AbstractEvolution, - ConcreteEvolution, - EvolutionChatTemplate, - FormalEvolution, - FunnyEvolution, - GoofyEvolution, - InformalEvolution, - ReasoningEvolution, + AbstractEvolution, + ConcreteEvolution, + EvolutionChatTemplate, + FormalEvolution, + FunnyEvolution, + GoofyEvolution, + InformalEvolution, + ReasoningEvolution, ) from .evolver import UtteranceEvolver __all__ = [ - "AbstractEvolution", - "ConcreteEvolution", - "EvolutionChatTemplate", - "FormalEvolution", - "FunnyEvolution", - "GoofyEvolution", - "InformalEvolution", - "ReasoningEvolution", - "UtteranceEvolver", + "AbstractEvolution", + "ConcreteEvolution", + "EvolutionChatTemplate", + "FormalEvolution", + "FunnyEvolution", + "GoofyEvolution", + "InformalEvolution", + "ReasoningEvolution", + "UtteranceEvolver", ] diff --git a/autointent/generation/utterances/evolution/chat_templates/concrete.py b/autointent/generation/utterances/evolution/chat_templates/concrete.py index 4a7ab52f2..dcca78bac 100644 --- a/autointent/generation/utterances/evolution/chat_templates/concrete.py +++ b/autointent/generation/utterances/evolution/chat_templates/concrete.py @@ -29,10 +29,7 @@ class ConcreteEvolution(EvolutionChatTemplate): Message(role=Role.ASSISTANT, content="I want to reserve a table for 4 persons at 9 pm."), Message( role=Role.USER, - content=( - "Intent name: requesting technical support\n" - "Utterance: I'm having trouble with my laptop." - ), + content=("Intent name: requesting technical support\n" "Utterance: I'm having trouble with my laptop."), ), Message(role=Role.ASSISTANT, content="My laptop is constantly rebooting and overheating."), ] diff --git a/autointent/generation/utterances/evolution/chat_templates/goofy.py b/autointent/generation/utterances/evolution/chat_templates/goofy.py index 15a6fcb17..c53156054 100644 --- a/autointent/generation/utterances/evolution/chat_templates/goofy.py +++ b/autointent/generation/utterances/evolution/chat_templates/goofy.py @@ -36,8 +36,7 @@ class GoofyEvolution(EvolutionChatTemplate): ), ), Message( - role=Role.ASSISTANT, - content="My laptop's having an existential crisis—keeps rebooting and melting. Help!" + role=Role.ASSISTANT, content="My laptop's having an existential crisis—keeps rebooting and melting. Help!" ), ] diff --git a/autointent/nodes/_optimization/_node_optimizer.py b/autointent/nodes/_optimization/_node_optimizer.py index 868403f9b..0b3420c10 100644 --- a/autointent/nodes/_optimization/_node_optimizer.py +++ b/autointent/nodes/_optimization/_node_optimizer.py @@ -9,6 +9,7 @@ import torch +from autointent import Dataset from autointent.context import Context from autointent.custom_types import NodeType from autointent.modules.abc import Module @@ -31,7 +32,7 @@ def __init__( :param node_type: Node type :param search_space: Search space for the optimization - :param metric: Metric to optimize. + :param metrics: Metrics to optimize. """ self.node_type = node_type self.node_info = NODES_INFO[node_type] @@ -41,7 +42,7 @@ def __init__( if self.target_metric not in self.metrics: self.metrics.append(self.target_metric) - self.modules_search_spaces = search_space # TODO search space validation + self.modules_search_spaces = search_space self._logger = logging.getLogger(__name__) # TODO solve duplicate logging messages problem def fit(self, context: Context) -> None: @@ -143,3 +144,25 @@ def module_fit(self, module: Module, context: Context) -> None: self._logger.error(msg) raise ValueError(msg) module.fit(*args) # type: ignore[arg-type] + + def validate_nodes_with_dataset(self, dataset: Dataset) -> None: + """ + Validate nodes with dataset. + + :param dataset: Dataset to use + """ + is_multilabel = dataset.multilabel + + for search_space in deepcopy(self.modules_search_spaces): + module_name = search_space.pop("module_name") + module = self.node_info.modules_available[module_name] + # todo add check for oos + + if is_multilabel and not module.supports_multilabel: + msg = f"Module '{module_name}' does not support multilabel datasets." + self._logger.error(msg) + raise ValueError(msg) + if not is_multilabel and not module.supports_multiclass: + msg = f"Module '{module_name}' does not support multiclass datasets." + self._logger.error(msg) + raise ValueError(msg) diff --git a/tests/pipeline/test_optimization.py b/tests/pipeline/test_optimization.py index 050eca742..9e53525f5 100644 --- a/tests/pipeline/test_optimization.py +++ b/tests/pipeline/test_optimization.py @@ -77,3 +77,32 @@ def test_dump_modules(dataset, task_type): context.dump() assert os.listdir(pipeline_optimizer.logging_config.dump_dir) + + +def test_validate_search_space_multiclass(dataset): + search_space = [ + { + "node_type": "decision", + "target_metric": "decision_accuracy", + "search_space": [{"module_name": "threshold", "thresh": [0.5]}, {"module_name": "adaptive"}], + }, + ] + + pipeline_optimizer = Pipeline.from_search_space(search_space) + with pytest.raises(ValueError, match="Module 'adaptive' does not support multiclass datasets."): + pipeline_optimizer.validate_modules(dataset) + + +def test_validate_search_space_multilabel(dataset): + dataset = dataset.to_multilabel() + + search_space = [ + { + "node_type": "decision", + "target_metric": "decision_accuracy", + "search_space": [{"module_name": "threshold", "thresh": [0.5]}, {"module_name": "argmax"}], + }, + ] + pipeline_optimizer = Pipeline.from_search_space(search_space) + with pytest.raises(ValueError, match="Module 'argmax' does not support multilabel datasets."): + pipeline_optimizer.validate_modules(dataset)