Skip to content

Commit 6a478cd

Browse files
authored
add validation with dataset (#118)
1 parent 8b8abef commit 6a478cd

File tree

8 files changed

+106
-47
lines changed

8 files changed

+106
-47
lines changed

autointent/_dataset/_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def push_to_hub(self, repo_id: str, private: bool = False) -> None:
146146
Push dataset splits to a Hugging Face repository.
147147
148148
:param repo_id: ID of the Hugging Face repository.
149+
:param private: Whether the repository is private.
149150
"""
150151
for split_name, split in self.items():
151152
split.push_to_hub(repo_id, split=split_name, private=private)

autointent/_pipeline/_pipeline.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def fit(self, dataset: Dataset) -> Context:
138138
context.configure_logging(self.logging_config)
139139
context.configure_vector_index(self.vector_index_config, self.embedder_config)
140140
context.configure_cross_encoder(self.cross_encoder_config)
141-
141+
self.validate_modules(dataset)
142142
self._fit(context)
143143

144144
if context.is_ram_to_clear():
@@ -160,6 +160,16 @@ def fit(self, dataset: Dataset) -> Context:
160160

161161
return context
162162

163+
def validate_modules(self, dataset: Dataset) -> None:
164+
"""
165+
Validate modules with dataset.
166+
167+
:param dataset: dataset to validate with
168+
"""
169+
for node in self.nodes.values():
170+
if isinstance(node, NodeOptimizer):
171+
node.validate_nodes_with_dataset(dataset)
172+
163173
@classmethod
164174
def from_dict_config(cls, nodes_configs: list[dict[str, Any]]) -> "Pipeline":
165175
"""
Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
11
from .basic import SynthesizerChatTemplate, UtteranceGenerator
22
from .evolution import (
3-
AbstractEvolution,
4-
ConcreteEvolution,
5-
EvolutionChatTemplate,
6-
FormalEvolution,
7-
FunnyEvolution,
8-
GoofyEvolution,
9-
InformalEvolution,
10-
ReasoningEvolution,
11-
UtteranceEvolver,
3+
AbstractEvolution,
4+
ConcreteEvolution,
5+
EvolutionChatTemplate,
6+
FormalEvolution,
7+
FunnyEvolution,
8+
GoofyEvolution,
9+
InformalEvolution,
10+
ReasoningEvolution,
11+
UtteranceEvolver,
1212
)
1313
from .generator import Generator
1414

1515
__all__ = [
16-
"AbstractEvolution",
17-
"ConcreteEvolution",
18-
"EvolutionChatTemplate",
19-
"FormalEvolution",
20-
"FunnyEvolution",
21-
"Generator",
22-
"GoofyEvolution",
23-
"InformalEvolution",
24-
"ReasoningEvolution",
25-
"SynthesizerChatTemplate",
26-
"UtteranceEvolver",
27-
"UtteranceGenerator",
16+
"AbstractEvolution",
17+
"ConcreteEvolution",
18+
"EvolutionChatTemplate",
19+
"FormalEvolution",
20+
"FunnyEvolution",
21+
"Generator",
22+
"GoofyEvolution",
23+
"InformalEvolution",
24+
"ReasoningEvolution",
25+
"SynthesizerChatTemplate",
26+
"UtteranceEvolver",
27+
"UtteranceGenerator",
2828
]
Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
from .chat_templates import (
2-
AbstractEvolution,
3-
ConcreteEvolution,
4-
EvolutionChatTemplate,
5-
FormalEvolution,
6-
FunnyEvolution,
7-
GoofyEvolution,
8-
InformalEvolution,
9-
ReasoningEvolution,
2+
AbstractEvolution,
3+
ConcreteEvolution,
4+
EvolutionChatTemplate,
5+
FormalEvolution,
6+
FunnyEvolution,
7+
GoofyEvolution,
8+
InformalEvolution,
9+
ReasoningEvolution,
1010
)
1111
from .evolver import UtteranceEvolver
1212

1313
__all__ = [
14-
"AbstractEvolution",
15-
"ConcreteEvolution",
16-
"EvolutionChatTemplate",
17-
"FormalEvolution",
18-
"FunnyEvolution",
19-
"GoofyEvolution",
20-
"InformalEvolution",
21-
"ReasoningEvolution",
22-
"UtteranceEvolver",
14+
"AbstractEvolution",
15+
"ConcreteEvolution",
16+
"EvolutionChatTemplate",
17+
"FormalEvolution",
18+
"FunnyEvolution",
19+
"GoofyEvolution",
20+
"InformalEvolution",
21+
"ReasoningEvolution",
22+
"UtteranceEvolver",
2323
]

autointent/generation/utterances/evolution/chat_templates/concrete.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,7 @@ class ConcreteEvolution(EvolutionChatTemplate):
2929
Message(role=Role.ASSISTANT, content="I want to reserve a table for 4 persons at 9 pm."),
3030
Message(
3131
role=Role.USER,
32-
content=(
33-
"Intent name: requesting technical support\n"
34-
"Utterance: I'm having trouble with my laptop."
35-
),
32+
content=("Intent name: requesting technical support\n" "Utterance: I'm having trouble with my laptop."),
3633
),
3734
Message(role=Role.ASSISTANT, content="My laptop is constantly rebooting and overheating."),
3835
]

autointent/generation/utterances/evolution/chat_templates/goofy.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ class GoofyEvolution(EvolutionChatTemplate):
3636
),
3737
),
3838
Message(
39-
role=Role.ASSISTANT,
40-
content="My laptop's having an existential crisis—keeps rebooting and melting. Help!"
39+
role=Role.ASSISTANT, content="My laptop's having an existential crisis—keeps rebooting and melting. Help!"
4140
),
4241
]
4342

autointent/nodes/_optimization/_node_optimizer.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import torch
1111

12+
from autointent import Dataset
1213
from autointent.context import Context
1314
from autointent.custom_types import NodeType
1415
from autointent.modules.abc import Module
@@ -31,7 +32,7 @@ def __init__(
3132
3233
:param node_type: Node type
3334
:param search_space: Search space for the optimization
34-
:param metric: Metric to optimize.
35+
:param metrics: Metrics to optimize.
3536
"""
3637
self.node_type = node_type
3738
self.node_info = NODES_INFO[node_type]
@@ -41,7 +42,7 @@ def __init__(
4142
if self.target_metric not in self.metrics:
4243
self.metrics.append(self.target_metric)
4344

44-
self.modules_search_spaces = search_space # TODO search space validation
45+
self.modules_search_spaces = search_space
4546
self._logger = logging.getLogger(__name__) # TODO solve duplicate logging messages problem
4647

4748
def fit(self, context: Context) -> None:
@@ -143,3 +144,25 @@ def module_fit(self, module: Module, context: Context) -> None:
143144
self._logger.error(msg)
144145
raise ValueError(msg)
145146
module.fit(*args) # type: ignore[arg-type]
147+
148+
def validate_nodes_with_dataset(self, dataset: Dataset) -> None:
149+
"""
150+
Validate nodes with dataset.
151+
152+
:param dataset: Dataset to use
153+
"""
154+
is_multilabel = dataset.multilabel
155+
156+
for search_space in deepcopy(self.modules_search_spaces):
157+
module_name = search_space.pop("module_name")
158+
module = self.node_info.modules_available[module_name]
159+
# todo add check for oos
160+
161+
if is_multilabel and not module.supports_multilabel:
162+
msg = f"Module '{module_name}' does not support multilabel datasets."
163+
self._logger.error(msg)
164+
raise ValueError(msg)
165+
if not is_multilabel and not module.supports_multiclass:
166+
msg = f"Module '{module_name}' does not support multiclass datasets."
167+
self._logger.error(msg)
168+
raise ValueError(msg)

tests/pipeline/test_optimization.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,32 @@ def test_dump_modules(dataset, task_type):
7777
context.dump()
7878

7979
assert os.listdir(pipeline_optimizer.logging_config.dump_dir)
80+
81+
82+
def test_validate_search_space_multiclass(dataset):
83+
search_space = [
84+
{
85+
"node_type": "decision",
86+
"target_metric": "decision_accuracy",
87+
"search_space": [{"module_name": "threshold", "thresh": [0.5]}, {"module_name": "adaptive"}],
88+
},
89+
]
90+
91+
pipeline_optimizer = Pipeline.from_search_space(search_space)
92+
with pytest.raises(ValueError, match="Module 'adaptive' does not support multiclass datasets."):
93+
pipeline_optimizer.validate_modules(dataset)
94+
95+
96+
def test_validate_search_space_multilabel(dataset):
97+
dataset = dataset.to_multilabel()
98+
99+
search_space = [
100+
{
101+
"node_type": "decision",
102+
"target_metric": "decision_accuracy",
103+
"search_space": [{"module_name": "threshold", "thresh": [0.5]}, {"module_name": "argmax"}],
104+
},
105+
]
106+
pipeline_optimizer = Pipeline.from_search_space(search_space)
107+
with pytest.raises(ValueError, match="Module 'argmax' does not support multilabel datasets."):
108+
pipeline_optimizer.validate_modules(dataset)

0 commit comments

Comments
 (0)