Skip to content

Commit 43b671c

Browse files
authored
update multilabel handling (#229)
* update multilabel handling * add autoconvert to multilabel when read * remove enum from split * remove useless conversions * format
1 parent 218db9a commit 43b671c

File tree

6 files changed

+47
-12
lines changed

6 files changed

+47
-12
lines changed

autointent/_dataset/_dataset.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,19 +94,22 @@ def from_json(cls, filepath: str | Path) -> "Dataset":
9494
return JsonReader().read(filepath)
9595

9696
@classmethod
97-
def from_hub(cls, repo_name: str, intent_subset_name: str = Split.INTENTS) -> "Dataset":
97+
def from_hub(
98+
cls, repo_name: str, data_split: str = "default", intent_subset_name: str = Split.INTENTS
99+
) -> "Dataset":
98100
"""Loads a dataset from the Hugging Face Hub.
99101
100102
Args:
101103
repo_name: The name of the Hugging Face repository, like `DeepPavlov/clinc150`.
104+
data_split: The name of the dataset split to load, defaults to `default`.
102105
intent_subset_name: The name of the intent subset to load, defaults to `intents`.
103106
"""
104107
from ._reader import DictReader
105108

106-
splits = load_dataset(repo_name, "default")
109+
splits = load_dataset(repo_name, data_split)
107110
mapping = dict(**splits)
108111
if intent_subset_name in get_dataset_config_names(repo_name):
109-
mapping[Split.INTENTS] = load_dataset(repo_name, intent_subset_name, split=Split.INTENTS).to_list()
112+
mapping[Split.INTENTS] = load_dataset(repo_name, name=intent_subset_name, split=Split.INTENTS).to_list()
110113

111114
return DictReader().read(mapping)
112115

autointent/_pipeline/_pipeline.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
SearchSpacePreset,
2626
SearchSpaceValidationMode,
2727
)
28-
from autointent.metrics import DECISION_METRICS
28+
from autointent.metrics import DECISION_METRICS, DICISION_METRICS_MULTILABEL
2929
from autointent.nodes import InferenceNode, NodeOptimizer
3030
from autointent.utils import load_preset, load_search_space
3131

@@ -247,7 +247,8 @@ def fit(
247247

248248
if test_utterances is not None:
249249
predictions = self.predict(test_utterances)
250-
for metric_name, metric in DECISION_METRICS.items():
250+
metrics = DICISION_METRICS_MULTILABEL if context.data_handler.multilabel else DECISION_METRICS
251+
for metric_name, metric in metrics.items():
251252
context.optimization_info.pipeline_metrics[metric_name] = metric(
252253
context.data_handler.test_labels(),
253254
predictions,

autointent/metrics/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,14 @@
7373
scoring_roc_auc,
7474
)
7575

76-
SCORING_METRICS_MULTILABEL: dict[str, ScoringMetricFn] = SCORING_METRICS_MULTICLASS | _funcs_to_dict(
76+
SCORING_METRICS_MULTILABEL: dict[str, ScoringMetricFn] = _funcs_to_dict(
77+
# multiclass except for scoring_roc_auc
78+
scoring_accuracy,
79+
scoring_f1,
80+
scoring_log_likelihood,
81+
scoring_precision,
82+
scoring_recall,
83+
# multilabel
7784
scoring_hit_rate,
7885
scoring_map,
7986
scoring_neg_coverage,
@@ -88,6 +95,13 @@
8895
decision_roc_auc,
8996
)
9097

98+
DICISION_METRICS_MULTILABEL: dict[str, DecisionMetricFn] = _funcs_to_dict(
99+
decision_accuracy,
100+
decision_f1,
101+
decision_precision,
102+
decision_recall,
103+
)
104+
91105
REGEX_METRICS = _funcs_to_dict(regex_partial_accuracy, regex_partial_precision)
92106

93107
METRIC_FN = DecisionMetricFn | RegexMetricFn | RetrievalMetricFn | ScoringMetricFn

autointent/metrics/decision.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def decision_roc_auc(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -
180180
if y_pred_.ndim == y_true_.ndim == 1:
181181
return _decision_roc_auc_multiclass(y_true_, y_pred_)
182182
if y_pred_.ndim == y_true_.ndim == 2: # noqa: PLR2004
183+
# not working with 1 class in y_true
183184
return _decision_roc_auc_multilabel(y_true_, y_pred_)
184185
msg = "Something went wrong with labels dimensions"
185186
logger.error(msg)

autointent/nodes/_node_optimizer.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,18 +237,34 @@ def validate_nodes_with_dataset(self, dataset: Dataset, mode: SearchSpaceValidat
237237
filtered_search_space = []
238238
if is_multilabel and self.target_metric not in self.node_info.multilabel_available_metrics:
239239
handle_message_on_mode(
240-
mode, f"Target metric '{self.target_metric}' is not available for multilabel datasets.", True
240+
mode,
241+
f"Target metric '{self.target_metric}' is not available for multilabel datasets. "
242+
f"Available metrics: {list(self.node_info.multilabel_available_metrics.keys())}",
243+
True,
241244
)
242245
elif not is_multilabel and self.target_metric not in self.node_info.multiclass_available_metrics:
243246
handle_message_on_mode(
244-
mode, f"Target metric '{self.target_metric}' is not available for multiclass datasets.", True
247+
mode,
248+
f"Target metric '{self.target_metric}' is not available for multiclass datasets. "
249+
f"Available metrics: {list(self.node_info.multiclass_available_metrics.keys())}",
250+
True,
245251
)
246252

247253
for metric in self.metrics:
248254
if is_multilabel and metric not in self.node_info.multilabel_available_metrics:
249-
handle_message_on_mode(mode, f"Metric '{metric}' is not available for multilabel datasets.", True)
255+
handle_message_on_mode(
256+
mode,
257+
f"Metric '{metric}' is not available for multilabel datasets. "
258+
f"Available metrics: {list(self.node_info.multilabel_available_metrics.keys())}",
259+
True,
260+
)
250261
elif not is_multilabel and metric not in self.node_info.multiclass_available_metrics:
251-
handle_message_on_mode(mode, f"Metric '{metric}' is not available for multiclass datasets.", True)
262+
handle_message_on_mode(
263+
mode,
264+
f"Metric '{metric}' is not available for multiclass datasets. "
265+
f"Available metrics: {list(self.node_info.multiclass_available_metrics.keys())}",
266+
True,
267+
)
252268

253269
for search_space in deepcopy(self.modules_search_spaces):
254270
module_name = search_space["module_name"]

autointent/nodes/info/_decision.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import ClassVar
55

66
from autointent.custom_types import NodeType
7-
from autointent.metrics import DECISION_METRICS, DecisionMetricFn
7+
from autointent.metrics import DECISION_METRICS, DICISION_METRICS_MULTILABEL, DecisionMetricFn
88
from autointent.modules import DECISION_MODULES
99
from autointent.modules.base import BaseDecision
1010

@@ -22,4 +22,4 @@ class DecisionNodeInfo(NodeInfo):
2222

2323
multiclass_available_metrics = DECISION_METRICS
2424

25-
multilabel_available_metrics = DECISION_METRICS
25+
multilabel_available_metrics = DICISION_METRICS_MULTILABEL

0 commit comments

Comments
 (0)