Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autointent/_dataset/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def push_to_hub(self, repo_id: str, private: bool = False) -> None:
Push dataset splits to a Hugging Face repository.

:param repo_id: ID of the Hugging Face repository.
:param private: Whether the repository is private
"""
for split_name, split in self.items():
split.push_to_hub(repo_id, split=split_name, private=private)
Expand Down
12 changes: 11 additions & 1 deletion autointent/_pipeline/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def fit(self, dataset: Dataset) -> Context:
context.configure_logging(self.logging_config)
context.configure_vector_index(self.vector_index_config, self.embedder_config)
context.configure_cross_encoder(self.cross_encoder_config)

self.validate_modules(dataset)
self._fit(context)

if context.is_ram_to_clear():
Expand All @@ -160,6 +160,16 @@ def fit(self, dataset: Dataset) -> Context:

return context

def validate_modules(self, dataset: Dataset) -> None:
"""
Validate modules with dataset.

:param dataset: dataset to validate with
"""
for node in self.nodes.values():
if isinstance(node, NodeOptimizer):
node.validate_nodes_with_dataset(dataset)

@classmethod
def from_dict_config(cls, nodes_configs: list[dict[str, Any]]) -> "Pipeline":
"""
Expand Down
42 changes: 21 additions & 21 deletions autointent/generation/utterances/__init__.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
from .basic import SynthesizerChatTemplate, UtteranceGenerator
from .evolution import (
AbstractEvolution,
ConcreteEvolution,
EvolutionChatTemplate,
FormalEvolution,
FunnyEvolution,
GoofyEvolution,
InformalEvolution,
ReasoningEvolution,
UtteranceEvolver,
AbstractEvolution,
ConcreteEvolution,
EvolutionChatTemplate,
FormalEvolution,
FunnyEvolution,
GoofyEvolution,
InformalEvolution,
ReasoningEvolution,
UtteranceEvolver,
)
from .generator import Generator

__all__ = [
"AbstractEvolution",
"ConcreteEvolution",
"EvolutionChatTemplate",
"FormalEvolution",
"FunnyEvolution",
"Generator",
"GoofyEvolution",
"InformalEvolution",
"ReasoningEvolution",
"SynthesizerChatTemplate",
"UtteranceEvolver",
"UtteranceGenerator",
"AbstractEvolution",
"ConcreteEvolution",
"EvolutionChatTemplate",
"FormalEvolution",
"FunnyEvolution",
"Generator",
"GoofyEvolution",
"InformalEvolution",
"ReasoningEvolution",
"SynthesizerChatTemplate",
"UtteranceEvolver",
"UtteranceGenerator",
]
34 changes: 17 additions & 17 deletions autointent/generation/utterances/evolution/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
from .chat_templates import (
AbstractEvolution,
ConcreteEvolution,
EvolutionChatTemplate,
FormalEvolution,
FunnyEvolution,
GoofyEvolution,
InformalEvolution,
ReasoningEvolution,
AbstractEvolution,
ConcreteEvolution,
EvolutionChatTemplate,
FormalEvolution,
FunnyEvolution,
GoofyEvolution,
InformalEvolution,
ReasoningEvolution,
)
from .evolver import UtteranceEvolver

__all__ = [
"AbstractEvolution",
"ConcreteEvolution",
"EvolutionChatTemplate",
"FormalEvolution",
"FunnyEvolution",
"GoofyEvolution",
"InformalEvolution",
"ReasoningEvolution",
"UtteranceEvolver",
"AbstractEvolution",
"ConcreteEvolution",
"EvolutionChatTemplate",
"FormalEvolution",
"FunnyEvolution",
"GoofyEvolution",
"InformalEvolution",
"ReasoningEvolution",
"UtteranceEvolver",
]
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ class ConcreteEvolution(EvolutionChatTemplate):
Message(role=Role.ASSISTANT, content="I want to reserve a table for 4 persons at 9 pm."),
Message(
role=Role.USER,
content=(
"Intent name: requesting technical support\n"
"Utterance: I'm having trouble with my laptop."
),
content=("Intent name: requesting technical support\n" "Utterance: I'm having trouble with my laptop."),
),
Message(role=Role.ASSISTANT, content="My laptop is constantly rebooting and overheating."),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ class GoofyEvolution(EvolutionChatTemplate):
),
),
Message(
role=Role.ASSISTANT,
content="My laptop's having an existential crisis—keeps rebooting and melting. Help!"
role=Role.ASSISTANT, content="My laptop's having an existential crisis—keeps rebooting and melting. Help!"
),
]

Expand Down
27 changes: 25 additions & 2 deletions autointent/nodes/_optimization/_node_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import torch

from autointent import Dataset
from autointent.context import Context
from autointent.custom_types import NodeType
from autointent.modules.abc import Module
Expand All @@ -31,7 +32,7 @@ def __init__(

:param node_type: Node type
:param search_space: Search space for the optimization
:param metric: Metric to optimize.
:param metrics: Metrics to optimize.
"""
self.node_type = node_type
self.node_info = NODES_INFO[node_type]
Expand All @@ -41,7 +42,7 @@ def __init__(
if self.target_metric not in self.metrics:
self.metrics.append(self.target_metric)

self.modules_search_spaces = search_space # TODO search space validation
self.modules_search_spaces = search_space
self._logger = logging.getLogger(__name__) # TODO solve duplicate logging messages problem

def fit(self, context: Context) -> None:
Expand Down Expand Up @@ -143,3 +144,25 @@ def module_fit(self, module: Module, context: Context) -> None:
self._logger.error(msg)
raise ValueError(msg)
module.fit(*args) # type: ignore[arg-type]

def validate_nodes_with_dataset(self, dataset: Dataset) -> None:
"""
Validate nodes with dataset.

:param dataset: Dataset to use
"""
is_multilabel = dataset.multilabel

for search_space in deepcopy(self.modules_search_spaces):
module_name = search_space.pop("module_name")
module = self.node_info.modules_available[module_name]
# todo add check for oos

if is_multilabel and not module.supports_multilabel:
msg = f"Module '{module_name}' does not support multilabel datasets."
self._logger.error(msg)
raise ValueError(msg)
if not is_multilabel and not module.supports_multiclass:
msg = f"Module '{module_name}' does not support multiclass datasets."
self._logger.error(msg)
raise ValueError(msg)
29 changes: 29 additions & 0 deletions tests/pipeline/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,32 @@ def test_dump_modules(dataset, task_type):
context.dump()

assert os.listdir(pipeline_optimizer.logging_config.dump_dir)


def test_validate_search_space_multiclass(dataset):
search_space = [
{
"node_type": "decision",
"target_metric": "decision_accuracy",
"search_space": [{"module_name": "threshold", "thresh": [0.5]}, {"module_name": "adaptive"}],
},
]

pipeline_optimizer = Pipeline.from_search_space(search_space)
with pytest.raises(ValueError, match="Module 'adaptive' does not support multiclass datasets."):
pipeline_optimizer.validate_modules(dataset)


def test_validate_search_space_multilabel(dataset):
dataset = dataset.to_multilabel()

search_space = [
{
"node_type": "decision",
"target_metric": "decision_accuracy",
"search_space": [{"module_name": "threshold", "thresh": [0.5]}, {"module_name": "argmax"}],
},
]
pipeline_optimizer = Pipeline.from_search_space(search_space)
with pytest.raises(ValueError, match="Module 'argmax' does not support multilabel datasets."):
pipeline_optimizer.validate_modules(dataset)