Merged
8 changes: 4 additions & 4 deletions autointent/_datafiles/default-multiclass-config.yaml
@@ -1,14 +1,14 @@
# TODO: make up a better and more versatile config
- node_type: embedding
metric: retrieval_hit_rate
target_metric: retrieval_hit_rate
search_space:
- module_name: retrieval
k: [10]
embedder_name:
- avsolatorio/GIST-small-Embedding-v0
- infgrad/stella-base-en-v2
- node_type: scoring
metric: scoring_roc_auc
target_metric: scoring_roc_auc
search_space:
- module_name: knn
k: [1, 3, 5, 10]
@@ -20,8 +20,8 @@
- cross-encoder/ms-marco-MiniLM-L-6-v2
k: [1, 3, 5, 10]
- node_type: decision
metric: decision_accuracy
target_metric: decision_accuracy
search_space:
- module_name: threshold
thresh: [0.5]
- module_name: argmax
- module_name: argmax
8 changes: 4 additions & 4 deletions autointent/_datafiles/default-multilabel-config.yaml
@@ -1,21 +1,21 @@
# TODO: make up a better and more versatile config
- node_type: embedding
metric: retrieval_hit_rate_intersecting
target_metric: retrieval_hit_rate_intersecting
search_space:
- module_name: retrieval
k: [10]
embedder_name:
- deepvk/USER-bge-m3
- node_type: scoring
metric: scoring_roc_auc
target_metric: scoring_roc_auc
search_space:
- module_name: knn
k: [3]
weights: ["uniform", "distance", "closest"]
- module_name: linear
- node_type: decision
metric: decision_accuracy
target_metric: decision_accuracy
search_space:
- module_name: threshold
thresh: [0.5]
- module_name: adaptive
- module_name: adaptive
8 changes: 2 additions & 6 deletions autointent/modules/abc/_base.py
@@ -34,11 +34,7 @@ def fit(self, *args: tuple[Any], **kwargs: dict[str, Any]) -> None:
"""

@abstractmethod
def score(
self,
context: Context,
split: Literal["validation", "test"],
) -> dict[str, float | str]:
def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
"""
Calculate metric on test set and return metric value.

@@ -110,7 +106,7 @@ def get_embedder_name(self) -> str | None:
return None

@staticmethod
def score_metrics(params: tuple[Any, Any], metrics_dict: dict[str, Any]) -> dict[str, float | str]:
def score_metrics(params: tuple[Any, Any], metrics_dict: dict[str, Any]) -> dict[str, float]:
"""
Score metrics on the test set.

9 changes: 3 additions & 6 deletions autointent/modules/abc/_decision.py
@@ -40,11 +40,7 @@ def predict(self, scores: npt.NDArray[Any]) -> ListOfGenericLabels:
:param scores: Scores to predict
"""

def score(
self,
context: Context,
split: Literal["validation", "test"],
) -> dict[str, float | str]:
def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
Member
The docstrings haven't been updated anywhere.

"""
Calculate metric on test set and return metric value.

@@ -54,7 +50,8 @@ def score(
"""
labels, scores = get_decision_evaluation_data(context, split)
self._decisions = self.predict(scores)
return self.score_metrics((labels, self._decisions), PREDICTION_METRICS_MULTICLASS)
chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS_MULTICLASS.items() if name in metrics}
return self.score_metrics((labels, self._decisions), chosen_metrics)

def get_assets(self) -> DecisionArtifact:
"""Return useful assets that represent intermediate data into context."""
9 changes: 3 additions & 6 deletions autointent/modules/abc/_scoring.py
@@ -21,11 +21,7 @@ class ScoringModule(Module, ABC):

supports_oos = False

def score(
self,
context: Context,
split: Literal["validation", "test"],
) -> dict[str, float | str]:
def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
"""
Evaluate the scorer on a test set and compute the specified metric.

@@ -50,7 +46,8 @@ def score(
self._test_scores = self.predict(context.data_handler.test_utterances())

metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS
return self.score_metrics((labels, scores), metrics_dict)
chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics}
return self.score_metrics((labels, scores), chosen_metrics)

def get_assets(self) -> ScorerArtifact:
"""
9 changes: 3 additions & 6 deletions autointent/modules/embedding/_logreg.py
@@ -129,11 +129,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:

self._classifier.fit(embeddings, labels)

def score(
self,
context: Context,
split: Literal["validation", "test"],
) -> dict[str, float | str]:
def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
"""
Evaluate the embedding model using a specified metric function.

@@ -153,7 +149,8 @@ def score(

probas = self.predict(utterances)
metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS
return self.score_metrics((labels, probas), metrics_dict)
chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics}
return self.score_metrics((labels, probas), chosen_metrics)

def get_assets(self) -> RetrieverArtifact:
"""
9 changes: 3 additions & 6 deletions autointent/modules/embedding/_retrieval.py
@@ -109,11 +109,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
)
self._vector_index.add(utterances, labels)

def score(
self,
context: Context,
split: Literal["validation", "test"],
) -> dict[str, float | str]:
def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
"""
Evaluate the embedding model using a specified metric function.

@@ -133,7 +129,8 @@ def score(
predictions, _, _ = self._vector_index.query(utterances, self.k)

metrics_dict = RETRIEVAL_METRICS_MULTILABEL if context.is_multilabel() else RETRIEVAL_METRICS_MULTICLASS
return self.score_metrics((labels, predictions), metrics_dict)
chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics}
return self.score_metrics((labels, predictions), chosen_metrics)

def get_assets(self) -> RetrieverArtifact:
"""
9 changes: 3 additions & 6 deletions autointent/modules/regexp/_regexp.py
@@ -108,11 +108,7 @@ def _predict_single(self, utterance: str) -> tuple[LabelType, dict[str, list[str
matches["partial_matches"].extend(intent_matches["partial_matches"])
return list(prediction), matches

def score(
self,
context: Context,
split: Literal["validation", "test"],
) -> dict[str, float | str]:
def score(self, context: Context, split: Literal["validation", "test"], metrics: list[str]) -> dict[str, float]:
"""
Calculate metric on test set and return metric value.

@@ -128,7 +124,8 @@ def score(
if assets["test_matches"] is None:
msg = "no matches found"
raise ValueError(msg)
return self.score_metrics((context.data_handler.test_labels(), assets["test_matches"]), REGEXP_METRICS)
chosen_metrics = {name: fn for name, fn in REGEXP_METRICS.items() if name in metrics}
return self.score_metrics((context.data_handler.test_labels(), assets["test_matches"]), chosen_metrics)

def clear_cache(self) -> None:
"""Clear cache."""
27 changes: 17 additions & 10 deletions autointent/nodes/_optimization/_node_optimizer.py
@@ -19,7 +19,13 @@
class NodeOptimizer:
"""Node optimizer class."""

def __init__(self, node_type: NodeType, search_space: list[dict[str, Any]], metric: str) -> None:
def __init__(
self,
node_type: NodeType,
search_space: list[dict[str, Any]],
target_metric: str,
metrics: list[str] | None = None,
) -> None:
"""
Initialize the node optimizer.

@@ -29,7 +35,12 @@ def __init__(self, node_type: NodeType, search_space: list[dict[str, Any]], metr
"""
self.node_type = node_type
self.node_info = NODES_INFO[node_type]
self.metric_name = metric
self.decision_metric_name = target_metric
Member
Why do we still have decision_metric_name here?


self.metrics = metrics if metrics is not None else []
if self.decision_metric_name not in self.metrics:
Member
@Samoed Samoed Jan 31, 2025
And what happens if it isn't in our metrics?

self.metrics.append(self.decision_metric_name)

self.modules_search_spaces = search_space # TODO search space validation
self._logger = logging.getLogger(__name__) # TODO solve duplicate logging messages problem

@@ -61,14 +72,10 @@ def fit(self, context: Context) -> None:
self.module_fit(module, context)

self._logger.debug("scoring %s module...", module_name)
metrics = module.score(context, "validation")
metric_value = metrics[self.metric_name]

# some metrics can produce error. When main metric produces error raise it.
if isinstance(metric_value, str):
raise Exception(metric_value) # noqa: TRY004, TRY002
metrics_score = module.score(context, "validation", self.metrics)
metric_value = metrics_score[self.decision_metric_name]

context.callback_handler.log_metrics(metrics)
context.callback_handler.log_metrics(metrics_score)
context.callback_handler.end_module()

dump_dir = context.get_dump_dir()
@@ -84,7 +91,7 @@
module_name,
module_kwargs,
metric_value,
self.metric_name,
self.decision_metric_name,
module.get_assets(), # retriever name / scores / predictions
module_dump_dir,
module=module if not context.is_ram_to_clear() else None,
6 changes: 3 additions & 3 deletions tests/assets/configs/description.yaml
@@ -1,16 +1,16 @@
- node_type: embedding
metric: retrieval_hit_rate
target_metric: retrieval_hit_rate
search_space:
- module_name: retrieval
k: [10]
embedder_name:
- sentence-transformers/all-MiniLM-L6-v2
- node_type: scoring
metric: scoring_roc_auc
target_metric: scoring_roc_auc
search_space:
- module_name: description
temperature: [1.0, 0.5, 0.1, 0.05]
- node_type: decision
metric: decision_accuracy
target_metric: decision_accuracy
search_space:
- module_name: argmax
8 changes: 4 additions & 4 deletions tests/assets/configs/multiclass.yaml
@@ -1,13 +1,13 @@
- node_type: embedding
metric: retrieval_hit_rate
target_metric: retrieval_hit_rate
search_space:
- module_name: retrieval
k: [10]
embedder_name:
- sentence-transformers/all-MiniLM-L6-v2
- avsolatorio/GIST-small-Embedding-v0
- node_type: scoring
metric: scoring_roc_auc
target_metric: scoring_roc_auc
search_space:
- module_name: knn
k: [5, 10]
@@ -32,10 +32,10 @@
cross_encoder_name:
- cross-encoder/ms-marco-MiniLM-L-6-v2
- node_type: decision
metric: decision_accuracy
target_metric: decision_accuracy
search_space:
- module_name: threshold
thresh: [0.5, [0.5, 0.5, 0.5, 0.5]]
- module_name: tunable
- module_name: argmax
- module_name: jinoos
- module_name: jinoos
6 changes: 3 additions & 3 deletions tests/assets/configs/multilabel.yaml
@@ -1,13 +1,13 @@
- node_type: embedding
metric: scoring_accuracy
target_metric: scoring_accuracy
search_space:
- module_name: logreg
cv: [2]
embedder_name:
- sentence-transformers/all-MiniLM-L6-v2
- avsolatorio/GIST-small-Embedding-v0
- node_type: scoring
metric: scoring_roc_auc
target_metric: scoring_roc_auc
search_space:
- module_name: knn
k: [5, 10]
@@ -28,7 +28,7 @@
cross_encoder_name:
- cross-encoder/ms-marco-MiniLM-L-6-v2
- node_type: decision
metric: decision_accuracy
target_metric: decision_accuracy
search_space:
- module_name: threshold
thresh: [0.5, [0.5, 0.5, 0.5, 0.5]]
21 changes: 18 additions & 3 deletions tests/callback/test_callback.py
@@ -47,7 +47,8 @@ def test_pipeline_callbacks(dataset):
search_space = [
{
"node_type": "embedding",
"metric": "retrieval_hit_rate",
"target_metric": "retrieval_hit_rate",
"metrics": ["retrieval_map", "retrieval_mrr", "retrieval_ndcg", "retrieval_precision"],
"search_space": [
{
"module_name": "retrieval",
@@ -58,15 +59,29 @@
},
{
"node_type": "scoring",
"metric": "scoring_roc_auc",
"target_metric": "scoring_roc_auc",
"metrics": [
"scoring_accuracy",
"scoring_f1",
"scoring_log_likelihood",
"scoring_precision",
"scoring_recall",
],
"search_space": [
{"module_name": "knn", "k": [1], "weights": ["uniform", "distance"]},
{"module_name": "linear"},
],
},
{
"node_type": "decision",
"metric": "decision_accuracy",
"target_metric": "decision_accuracy",
"metrics": [
"decision_accuracy",
"decision_f1",
"decision_precision",
"decision_recall",
"decision_roc_auc",
],
"search_space": [{"module_name": "threshold", "thresh": [0.5]}, {"module_name": "argmax"}],
},
]
6 changes: 3 additions & 3 deletions tests/nodes/conftest.py
@@ -26,7 +26,7 @@ def get_embedding_optimizer(multilabel: bool):
if multilabel:
metric = metric + "_intersecting"
embedding_optimizer_config = {
"metric": metric,
"target_metric": metric,
"node_type": "embedding",
"search_space": [
{
@@ -48,7 +48,7 @@ def scoring_optimizer_multiclass(embedding_optimizer_multiclass):
embedding_optimizer_multiclass.fit(context)

scoring_optimizer_config = {
"metric": "scoring_roc_auc",
"target_metric": "scoring_roc_auc",
"node_type": "scoring",
"search_space": [
{"module_name": "linear"},
@@ -64,7 +64,7 @@ def scoring_optimizer_multilabel(embedding_optimizer_multilabel):
embedding_optimizer_multilabel.fit(context)

scoring_optimizer_config = {
"metric": "scoring_roc_auc",
"target_metric": "scoring_roc_auc",
"node_type": "scoring",
"search_space": [
{"module_name": "linear"},