Skip to content

Commit 51dcc7e

Browse files
committed
update metrics
1 parent 9faa681 commit 51dcc7e

File tree

7 files changed

+33
-29
lines changed

7 files changed

+33
-29
lines changed

autointent/nodes/_optimization/_node_optimizer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ def __init__(
3535
"""
3636
self.node_type = node_type
3737
self.node_info = NODES_INFO[node_type]
38-
self.decision_metric_name = target_metric
38+
self.target_metric = target_metric
3939

4040
self.metrics = metrics if metrics is not None else []
41-
if self.decision_metric_name not in self.metrics:
42-
self.metrics.append(self.decision_metric_name)
41+
if self.target_metric not in self.metrics:
42+
self.metrics.append(self.target_metric)
4343

4444
self.modules_search_spaces = search_space # TODO search space validation
4545
self._logger = logging.getLogger(__name__) # TODO solve duplicate logging messages problem
@@ -73,7 +73,7 @@ def fit(self, context: Context) -> None:
7373

7474
self._logger.debug("scoring %s module...", module_name)
7575
metrics_score = module.score(context, "validation", self.metrics)
76-
metric_value = metrics_score[self.decision_metric_name]
76+
metric_value = metrics_score[self.target_metric]
7777

7878
context.callback_handler.log_metrics(metrics_score)
7979
context.callback_handler.end_module()
@@ -91,7 +91,7 @@ def fit(self, context: Context) -> None:
9191
module_name,
9292
module_kwargs,
9393
metric_value,
94-
self.decision_metric_name,
94+
self.target_metric,
9595
module.get_assets(), # retriever name / scores / predictions
9696
module_dump_dir,
9797
module=module if not context.is_ram_to_clear() else None,

autointent/nodes/schemes.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ class DecisionNodeValidator(BaseModel):
5656
"""Search space configuration for the Decision node."""
5757

5858
node_type: NodeType = NodeType.decision
59-
metric: DecisionMetrics
59+
target_metric: DecisionMetrics
60+
metrics: list[DecisionMetrics] | None = None
6061
search_space: list[DecisionSearchSpaceType]
6162

6263

@@ -70,7 +71,8 @@ class EmbeddingNodeValidator(BaseModel):
7071
"""Search space configuration for the Embedding node."""
7172

7273
node_type: NodeType = NodeType.embedding
73-
metric: EmbeddingMetrics
74+
target_metric: EmbeddingMetrics
75+
metrics: list[EmbeddingMetrics] | None = None
7476
search_space: list[EmbeddingSearchSpaceType]
7577

7678

@@ -84,7 +86,8 @@ class ScoringNodeValidator(BaseModel):
8486
"""Search space configuration for the Scoring node."""
8587

8688
node_type: NodeType = NodeType.scoring
87-
metric: ScoringMetrics
89+
target_metric: ScoringMetrics
90+
metrics: list[ScoringMetrics] | None = None
8891
search_space: list[ScoringSearchSpaceType]
8992

9093

@@ -98,7 +101,8 @@ class RegexNodeValidator(BaseModel):
98101
"""Search space configuration for the Regexp node."""
99102

100103
node_type: NodeType = NodeType.regexp
101-
metric: RegexpMetrics
104+
target_metric: RegexpMetrics
105+
metrics: list[RegexpMetrics] | None = None
102106
search_space: list[RegexpSearchSpaceType]
103107

104108

tests/configs/test_combined_config.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def valid_optimizer_config():
1313
return [
1414
{
1515
"node_type": "scoring",
16-
"metric": "scoring_roc_auc",
16+
"target_metric": "scoring_roc_auc",
1717
"search_space": [
1818
{
1919
"module_name": "dnnc",
@@ -28,7 +28,7 @@ def valid_optimizer_config():
2828
},
2929
{
3030
"node_type": "embedding",
31-
"metric": "retrieval_hit_rate",
31+
"target_target_metric": "retrieval_hit_rate",
3232
"search_space": [
3333
{
3434
"module_name": "retrieval",
@@ -62,7 +62,7 @@ def test_invalid_optimizer_config_missing_field():
6262
invalid_config = [
6363
{
6464
"node_type": "scoring",
65-
# Missing "metric"
65+
# Missing "target_metric"
6666
"search_space": [
6767
{"module_name": "dnnc", "cross_encoder_name": ["cross-encoder/ms-marco-MiniLM-L-6-v2"], "k": [1, 3]}
6868
],
@@ -78,7 +78,7 @@ def test_invalid_optimizer_config_wrong_type():
7878
invalid_config = [
7979
{
8080
"node_type": "scoring",
81-
"metric": "scoring_roc_auc",
81+
"target_metric": "scoring_roc_auc",
8282
"search_space": [
8383
{
8484
"module_name": "dnnc",

tests/configs/test_decision.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def valid_decision_config():
1010
return [
1111
{
1212
"node_type": "decision",
13-
"metric": "decision_roc_auc",
13+
"target_metric": "decision_roc_auc",
1414
"search_space": [
1515
{"module_name": "argmax"},
1616
{"module_name": "jinoos", "search_space": [[0.3, 0.5, 0.7]]},
@@ -29,7 +29,7 @@ def test_valid_decision_config(valid_decision_config):
2929
"""Test that a valid decision config passes validation."""
3030
config = OptimizationConfig(valid_decision_config)
3131
assert config[0].node_type == "decision"
32-
assert config[0].metric == "decision_roc_auc"
32+
assert config[0].target_metric == "decision_roc_auc"
3333
assert isinstance(config[0].search_space, list)
3434
assert config[0].search_space[0].module_name == "argmax"
3535

@@ -39,7 +39,7 @@ def test_invalid_decision_config_missing_field():
3939
invalid_config = [
4040
{
4141
"node_type": "decision",
42-
# Missing "metric"
42+
# Missing "target_metric"
4343
"search_space": [{"module_name": "tunable", "n_trials": [100]}],
4444
}
4545
]
@@ -53,7 +53,7 @@ def test_invalid_decision_config_wrong_type():
5353
invalid_config = [
5454
{
5555
"node_type": "decision",
56-
"metric": "decision_roc_auc",
56+
"target_metric": "decision_roc_auc",
5757
"search_space": [
5858
{
5959
"module_name": "threshold",

tests/configs/test_embedding.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def valid_embedding_config():
1010
return [
1111
{
1212
"node_type": "embedding",
13-
"metric": "retrieval_mrr",
13+
"target_metric": "retrieval_mrr",
1414
"search_space": [
1515
{"module_name": "logreg_embedding", "embedder_name": ["sergeyzh/rubert-tiny-turbo"], "cv": [3, 5]},
1616
{"module_name": "retrieval", "embedder_name": ["sentence-transformers/all-MiniLM-L6-v2"], "k": [5, 10]},
@@ -23,7 +23,7 @@ def test_valid_embedding_config(valid_embedding_config):
2323
"""Test that a valid embedding config passes validation."""
2424
config = OptimizationConfig(valid_embedding_config)
2525
assert config[0].node_type == "embedding"
26-
assert config[0].metric == "retrieval_mrr"
26+
assert config[0].target_metric == "retrieval_mrr"
2727
assert isinstance(config[0].search_space, list)
2828
assert config[0].search_space[0].module_name == "logreg_embedding"
2929
assert "embedder_name" in config[0].search_space[0].model_dump()
@@ -34,7 +34,7 @@ def test_invalid_embedding_config_missing_field():
3434
invalid_config = [
3535
{
3636
"node_type": "embedding",
37-
# Missing "metric"
37+
# Missing "target_metric"
3838
"search_space": [
3939
{"module_name": "retrieval", "embedder_name": ["sentence-transformers/all-MiniLM-L6-v2"], "k": [5, 10]}
4040
],
@@ -50,7 +50,7 @@ def test_invalid_embedding_config_wrong_type():
5050
invalid_config = [
5151
{
5252
"node_type": "embedding",
53-
"metric": "retrieval_mrr",
53+
"target_metric": "retrieval_mrr",
5454
"search_space": [
5555
{
5656
"module_name": "logreg_embedding",

tests/configs/test_regex.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
@pytest.fixture
88
def valid_regexp_config():
99
"""Fixture for a valid RegExp node configuration."""
10-
return [{"node_type": "regexp", "metric": "regexp_partial_accuracy", "search_space": [{"module_name": "regexp"}]}]
10+
return [{"node_type": "regexp", "target_metric": "regexp_partial_accuracy", "search_space": [{"module_name": "regexp"}]}]
1111

1212

1313
def test_valid_regexp_config(valid_regexp_config):
1414
"""Test that a valid RegExp config passes validation."""
1515
config = OptimizationConfig(valid_regexp_config)
1616
assert config[0].node_type == "regexp"
17-
assert config[0].metric == "regexp_partial_accuracy"
17+
assert config[0].target_metric == "regexp_partial_accuracy"
1818
assert isinstance(config[0].search_space, list)
1919
assert config[0].search_space[0].module_name == "regexp"
2020

@@ -23,7 +23,7 @@ def test_invalid_regexp_config_missing_field():
2323
"""Test that a missing required field raises ValidationError."""
2424
invalid_config = {
2525
"node_type": "regexp",
26-
# Missing "metric"
26+
# Missing "target_metric"
2727
"search_space": [{"module_name": "regexp"}],
2828
}
2929

@@ -35,7 +35,7 @@ def test_invalid_regexp_config_wrong_type():
3535
"""Test that an invalid field type raises ValidationError."""
3636
invalid_config = {
3737
"node_type": "regexp",
38-
"metric": "regexp_partial_accuracy",
38+
"target_metric": "regexp_partial_accuracy",
3939
"search_space": "should_be_a_list", # Should be a list of dicts
4040
}
4141

tests/configs/test_scoring.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def valid_scoring_config():
1010
return [
1111
{
1212
"node_type": "scoring",
13-
"metric": "scoring_roc_auc",
13+
"target_metric": "scoring_roc_auc",
1414
"search_space": [
1515
{
1616
"module_name": "dnnc",
@@ -61,7 +61,7 @@ def test_valid_scoring_config(valid_scoring_config):
6161
"""Test that a valid scoring config passes validation."""
6262
config = OptimizationConfig(valid_scoring_config)
6363
assert config[0].node_type == "scoring"
64-
assert config[0].metric == "scoring_roc_auc"
64+
assert config[0].target_metric == "scoring_roc_auc"
6565
assert isinstance(config[0].search_space, list)
6666
assert config[0].search_space[0].module_name == "dnnc"
6767

@@ -70,7 +70,7 @@ def test_invalid_scoring_config_missing_field():
7070
"""Test that a missing required field raises ValidationError."""
7171
invalid_config = {
7272
"node_type": "scoring",
73-
# Missing "metric"
73+
# Missing "target_metric"
7474
"search_space": [
7575
{"module_name": "dnnc", "cross_encoder_name": ["cross-encoder/ms-marco-MiniLM-L-6-v2"], "k": [5, 10]}
7676
],
@@ -84,7 +84,7 @@ def test_invalid_scoring_config_wrong_type():
8484
"""Test that an invalid field type raises ValidationError."""
8585
invalid_config = {
8686
"node_type": "scoring",
87-
"metric": "scoring_roc_auc",
87+
"target_metric": "scoring_roc_auc",
8888
"search_space": [
8989
{
9090
"module_name": "knn",

0 commit comments

Comments
 (0)