Skip to content

Commit 0bcef43

Browse files
committed
refactor: decision metric to target metric
1 parent dcee174 commit 0bcef43

File tree

10 files changed

+23
-22
lines changed

10 files changed

+23
-22
lines changed

autointent/_datafiles/default-multiclass-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# TODO: make up a better and more versatile config
22
- node_type: embedding
3-
decision_metric: retrieval_hit_rate
3+
target_metric: retrieval_hit_rate
44
search_space:
55
- module_name: retrieval
66
k: [10]
77
embedder_name:
88
- avsolatorio/GIST-small-Embedding-v0
99
- infgrad/stella-base-en-v2
1010
- node_type: scoring
11-
decision_metric: scoring_roc_auc
11+
target_metric: scoring_roc_auc
1212
search_space:
1313
- module_name: knn
1414
k: [1, 3, 5, 10]
@@ -20,7 +20,7 @@
2020
- cross-encoder/ms-marco-MiniLM-L-6-v2
2121
k: [1, 3, 5, 10]
2222
- node_type: decision
23-
decision_metric: decision_accuracy
23+
target_metric: decision_accuracy
2424
search_space:
2525
- module_name: threshold
2626
thresh: [0.5]

autointent/_datafiles/default-multilabel-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
# TODO: make up a better and more versatile config
22
- node_type: embedding
3-
decision_metric: retrieval_hit_rate_intersecting
3+
target_metric: retrieval_hit_rate_intersecting
44
search_space:
55
- module_name: retrieval
66
k: [10]
77
embedder_name:
88
- deepvk/USER-bge-m3
99
- node_type: scoring
10-
decision_metric: scoring_roc_auc
10+
target_metric: scoring_roc_auc
1111
search_space:
1212
- module_name: knn
1313
k: [3]
1414
weights: ["uniform", "distance", "closest"]
1515
- module_name: linear
1616
- node_type: decision
17-
decision_metric: decision_accuracy
17+
target_metric: decision_accuracy
1818
search_space:
1919
- module_name: threshold
2020
thresh: [0.5]

autointent/nodes/_optimization/_node_optimizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ def __init__(
3838
self.decision_metric_name = decision_metric
3939

4040
self.metrics = metrics if metrics is not None else []
41-
self.metrics.append(self.decision_metric_name)
41+
if self.decision_metric_name not in self.metrics:
42+
self.metrics.append(self.decision_metric_name)
4243

4344
self.modules_search_spaces = search_space # TODO search space validation
4445
self._logger = logging.getLogger(__name__) # TODO solve duplicate logging messages problem
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
- node_type: embedding
2-
decision_metric: retrieval_hit_rate
2+
target_metric: retrieval_hit_rate
33
search_space:
44
- module_name: retrieval
55
k: [10]
66
embedder_name:
77
- sentence-transformers/all-MiniLM-L6-v2
88
- node_type: scoring
9-
decision_metric: scoring_roc_auc
9+
target_metric: scoring_roc_auc
1010
search_space:
1111
- module_name: description
1212
temperature: [1.0, 0.5, 0.1, 0.05]
1313
- node_type: decision
14-
decision_metric: decision_accuracy
14+
target_metric: decision_accuracy
1515
search_space:
1616
- module_name: argmax

tests/assets/configs/multiclass.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
- node_type: embedding
2-
decision_metric: retrieval_hit_rate
2+
target_metric: retrieval_hit_rate
33
search_space:
44
- module_name: retrieval
55
k: [10]
66
embedder_name:
77
- sentence-transformers/all-MiniLM-L6-v2
88
- avsolatorio/GIST-small-Embedding-v0
99
- node_type: scoring
10-
decision_metric: scoring_roc_auc
10+
target_metric: scoring_roc_auc
1111
search_space:
1212
- module_name: knn
1313
k: [5, 10]
@@ -32,7 +32,7 @@
3232
cross_encoder_name:
3333
- cross-encoder/ms-marco-MiniLM-L-6-v2
3434
- node_type: decision
35-
decision_metric: decision_accuracy
35+
target_metric: decision_accuracy
3636
search_space:
3737
- module_name: threshold
3838
thresh: [0.5, [0.5, 0.5, 0.5, 0.5]]

tests/assets/configs/multilabel.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
- node_type: embedding
2-
decision_metric: scoring_accuracy
2+
target_metric: scoring_accuracy
33
search_space:
44
- module_name: logreg
55
cv: [2]
66
embedder_name:
77
- sentence-transformers/all-MiniLM-L6-v2
88
- avsolatorio/GIST-small-Embedding-v0
99
- node_type: scoring
10-
decision_metric: scoring_roc_auc
10+
target_metric: scoring_roc_auc
1111
search_space:
1212
- module_name: knn
1313
k: [5, 10]
@@ -28,7 +28,7 @@
2828
cross_encoder_name:
2929
- cross-encoder/ms-marco-MiniLM-L-6-v2
3030
- node_type: decision
31-
decision_metric: decision_accuracy
31+
target_metric: decision_accuracy
3232
search_space:
3333
- module_name: threshold
3434
thresh: [0.5, [0.5, 0.5, 0.5, 0.5]]

tests/nodes/test_decision.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def test_decision_multiclass(scoring_optimizer_multiclass):
1515
scoring_optimizer_multiclass.fit(context)
1616

1717
decision_optimizer_config = {
18-
"metric": "decision_accuracy",
18+
"target_metric": "decision_accuracy",
1919
"node_type": "decision",
2020
"search_space": [
2121
{"module_name": "threshold", "thresh": [0.5]},
@@ -54,7 +54,7 @@ def test_decision_multilabel(scoring_optimizer_multilabel):
5454
scoring_optimizer_multilabel.fit(context)
5555

5656
decision_optimizer_config = {
57-
"metric": "decision_accuracy",
57+
"target_metric": "decision_accuracy",
5858
"node_type": "decision",
5959
"search_space": [
6060
{"module_name": "threshold", "thresh": [0.5]},

tests/nodes/test_logreg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def get_embedding_optimizer(multilabel: bool):
5656
if multilabel:
5757
metric = "scoring_neg_coverage"
5858
embedding_optimizer_config = {
59-
"metric": metric,
59+
"target_metric": metric,
6060
"node_type": "embedding",
6161
"search_space": [
6262
{

tests/nodes/test_retrieval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def get_embedding_optimizer(multilabel: bool):
5454
if multilabel:
5555
metric = metric + "_intersecting"
5656
embedding_optimizer_config = {
57-
"metric": metric,
57+
"target_metric": metric,
5858
"node_type": "embedding",
5959
"search_space": [
6060
{

tests/nodes/test_scoring.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_scoring_multiclass(embedding_optimizer_multiclass):
1616
embedding_optimizer_multiclass.fit(context)
1717

1818
scoring_optimizer_config = {
19-
"metric": "scoring_roc_auc",
19+
"target_metric": "scoring_roc_auc",
2020
"node_type": "scoring",
2121
"search_space": [
2222
{
@@ -80,7 +80,7 @@ def test_scoring_multilabel(embedding_optimizer_multilabel):
8080
embedding_optimizer_multilabel.fit(context)
8181

8282
scoring_optimizer_config = {
83-
"metric": "scoring_roc_auc",
83+
"target_metric": "scoring_roc_auc",
8484
"node_type": "scoring",
8585
"search_space": [
8686
{

0 commit comments

Comments
 (0)