Skip to content

Commit d61ee10

Browse files
authored
Check if metric can handle dataset type (#224)
* add test for configuration * lint * satisfy mypy
1 parent 4dcd54e commit d61ee10

File tree

8 files changed

+80
-10
lines changed

8 files changed

+80
-10
lines changed

autointent/nodes/_node_optimizer.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def get_module_dump_dir(self, context: Context, module_name: str, j_combination:
241241
dump_dir_.mkdir(parents=True, exist_ok=True)
242242
return str(dump_dir_)
243243

244-
def validate_nodes_with_dataset(self, dataset: Dataset, mode: SearchSpaceValidationMode) -> None:
244+
def validate_nodes_with_dataset(self, dataset: Dataset, mode: SearchSpaceValidationMode) -> None: # noqa: C901
245245
"""Validates nodes against the dataset.
246246
247247
Args:
@@ -254,12 +254,24 @@ def validate_nodes_with_dataset(self, dataset: Dataset, mode: SearchSpaceValidat
254254
is_multilabel = dataset.multilabel
255255

256256
filtered_search_space = []
257+
if is_multilabel and self.target_metric not in self.node_info.multilabel_available_metrics:
258+
handle_message_on_mode(
259+
mode, f"Target metric '{self.target_metric}' is not available for multilabel datasets.", True
260+
)
261+
elif not is_multilabel and self.target_metric not in self.node_info.multiclass_available_metrics:
262+
handle_message_on_mode(
263+
mode, f"Target metric '{self.target_metric}' is not available for multiclass datasets.", True
264+
)
265+
266+
for metric in self.metrics:
267+
if is_multilabel and metric not in self.node_info.multilabel_available_metrics:
268+
handle_message_on_mode(mode, f"Metric '{metric}' is not available for multilabel datasets.", True)
269+
elif not is_multilabel and metric not in self.node_info.multiclass_available_metrics:
270+
handle_message_on_mode(mode, f"Metric '{metric}' is not available for multiclass datasets.", True)
257271

258272
for search_space in deepcopy(self.modules_search_spaces):
259273
module_name = search_space["module_name"]
260274
module = self.node_info.modules_available[module_name]
261-
# todo add check for oos
262-
263275
messages = []
264276

265277
if module_name == "description" and not dataset.has_descriptions:
@@ -273,11 +285,7 @@ def validate_nodes_with_dataset(self, dataset: Dataset, mode: SearchSpaceValidat
273285

274286
if len(messages) > 0:
275287
msg = "\n".join(messages)
276-
if mode == "raise":
277-
self._logger.error(msg)
278-
raise ValueError(msg)
279-
if mode == "warning":
280-
self._logger.warning(msg)
288+
handle_message_on_mode(mode, msg)
281289
else:
282290
filtered_search_space.append(search_space)
283291

@@ -393,3 +401,26 @@ def load_or_create_study(
393401
finished_trials,
394402
remaining_trials,
395403
)
404+
405+
406+
def handle_message_on_mode(
    mode: SearchSpaceValidationMode,
    message: str,
    strict: bool = False,
) -> None:
    """Log or raise a validation message according to the validation mode.

    Args:
        mode: The validation mode. ``"raise"`` logs the message at error level
            and raises; ``"warning"`` logs the message at warning level. Any
            other value neither logs nor raises on its own.
        message: The message to log and to attach to the raised error.
        strict: If True, an error is raised regardless of ``mode`` (after the
            warning has been logged when ``mode`` is ``"warning"``).

    Raises:
        ValueError: If ``mode`` is ``"raise"``, or if ``strict`` is True.
    """
    if mode == "raise":
        # Record the failure in the log before raising, so it stays visible
        # even when a caller catches the exception.
        logger.error(message)
        raise ValueError(message)
    if mode == "warning":
        logger.warning(message)
    if strict:
        # Strict messages are fatal whatever the mode asked for.
        raise ValueError(message)

autointent/nodes/info/_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,7 @@ class NodeInfo:
1717
"""Available modules for the node."""
1818
node_type: NodeType
1919
"""Node type."""
20+
multiclass_available_metrics: ClassVar[Mapping[str, METRIC_FN]]
21+
"""Available metrics for multiclass classification."""
22+
multilabel_available_metrics: ClassVar[Mapping[str, METRIC_FN]]
23+
"""Available metrics for multilabel classification."""

autointent/nodes/info/_decision.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,7 @@ class DecisionNodeInfo(NodeInfo):
1919
modules_available: ClassVar[dict[str, type[BaseDecision]]] = DECISION_MODULES
2020

2121
node_type = NodeType.decision
22+
23+
multiclass_available_metrics = DECISION_METRICS
24+
25+
multilabel_available_metrics = DECISION_METRICS

autointent/nodes/info/_embedding.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Retrieval node info."""
22

33
from collections.abc import Mapping
4-
from typing import ClassVar
4+
from typing import ClassVar, cast
55

66
from autointent.custom_types import NodeType
77
from autointent.metrics import (
@@ -31,3 +31,11 @@ class EmbeddingNodeInfo(NodeInfo):
3131
modules_available: ClassVar[Mapping[str, type[BaseEmbedding]]] = EMBEDDING_MODULES
3232

3333
node_type = NodeType.embedding
34+
35+
multiclass_available_metrics: ClassVar[Mapping[str, RetrievalMetricFn | ScoringMetricFn]] = cast(
36+
Mapping[str, RetrievalMetricFn | ScoringMetricFn], RETRIEVAL_METRICS_MULTICLASS | SCORING_METRICS_MULTICLASS
37+
)
38+
39+
multilabel_available_metrics: ClassVar[Mapping[str, RetrievalMetricFn | ScoringMetricFn]] = cast(
40+
Mapping[str, RetrievalMetricFn | ScoringMetricFn], RETRIEVAL_METRICS_MULTILABEL | SCORING_METRICS_MULTILABEL
41+
)

autointent/nodes/info/_regex.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,7 @@ class RegexNodeInfo(NodeInfo):
2020
modules_available: ClassVar[Mapping[str, type[BaseRegex]]] = REGEX_MODULES
2121

2222
node_type = NodeType.regex
23+
24+
multiclass_available_metrics: ClassVar[Mapping[str, RegexMetricFn]] = REGEX_METRICS
25+
26+
multilabel_available_metrics: ClassVar[Mapping[str, RegexMetricFn]] = REGEX_METRICS

autointent/nodes/info/_scoring.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ class ScoringNodeInfo(NodeInfo):
1919
modules_available: ClassVar[Mapping[str, type[BaseScorer]]] = SCORING_MODULES
2020

2121
node_type = NodeType.scoring
22+
23+
multiclass_available_metrics: ClassVar[Mapping[str, ScoringMetricFn]] = SCORING_METRICS_MULTICLASS
24+
multilabel_available_metrics: ClassVar[Mapping[str, ScoringMetricFn]] = SCORING_METRICS_MULTILABEL

tests/assets/configs/multilabel.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
- model_name: sentence-transformers/all-MiniLM-L6-v2
88
- model_name: avsolatorio/GIST-small-Embedding-v0
99
- node_type: scoring
10-
target_metric: scoring_roc_auc
10+
target_metric: scoring_hit_rate
1111
search_space:
1212
- module_name: knn
1313
k: [5, 10]

tests/pipeline/test_optimization.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,19 @@ def test_dump_modules(dataset, task_type):
128128
context.dump()
129129

130130
assert os.listdir(pipeline_optimizer.logging_config.dump_dir)
131+
132+
133+
@pytest.mark.parametrize("task_type", ["multiclass", "multilabel"])
def test_optimization_validation_metric_names(dataset, task_type):
    """Fitting with a dataset type the target metric cannot handle must raise.

    The search space is built for ``task_type``. For the multiclass search
    space the dataset is converted to multilabel; for the multilabel search
    space the fixture dataset is used unchanged (presumably it is multiclass
    -- confirm against the ``dataset`` fixture). In both cases ``fit`` in
    ``"raise"`` mode is expected to fail with a "Target metric ..." error.
    """
    optimizer = Pipeline.from_search_space(get_search_space(task_type))

    # Only the multiclass configuration needs its dataset converted.
    fit_dataset = dataset.to_multilabel() if task_type == "multiclass" else dataset

    with pytest.raises(ValueError, match="Target metric .*"):
        optimizer.fit(fit_dataset, incompatible_search_space="raise")

0 commit comments

Comments
 (0)