Skip to content

Commit e8f9a4f

Browse files
committed
fix tests
1 parent e5ae3a5 commit e8f9a4f

File tree

11 files changed

+38
-82
lines changed

11 files changed

+38
-82
lines changed

autointent/modules/embedding/_retrieval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def from_context(
6565
cls,
6666
context: Context,
6767
k: int,
68-
embedder_config: EmbedderConfig,
68+
embedder_config: EmbedderConfig | str,
6969
) -> "RetrievalAimedEmbedding":
7070
"""
7171
Create an instance using a Context object.

autointent/modules/scoring/_description/description.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def get_embedder_name(self) -> str:
8181
8282
:return: Embedder name.
8383
"""
84-
return self.embedder_config.model_name
84+
return self.embedder_config
8585

8686
def fit(
8787
self,

autointent/modules/scoring/_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def get_embedder_name(self) -> str:
102102
103103
:return: Embedder name.
104104
"""
105-
return self.embedder_config.model_name
105+
return self.embedder_config
106106

107107
def fit(
108108
self,

autointent/modules/scoring/_mlknn/mlknn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def get_embedder_name(self) -> str:
115115
116116
:return: Embedder name.
117117
"""
118-
return self.embedder_config.model_name
118+
return self.embedder_config
119119

120120
def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
121121
"""

autointent/nodes/_optimization/_node_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def fit(self, context: Context) -> None:
6666

6767
embedder_name = module.get_embedder_name()
6868
if embedder_name is not None:
69-
module_kwargs["embedder_name"] = embedder_name
69+
module_kwargs["embedder_config"] = embedder_name
7070

7171
self._logger.debug("optimizing %s module...", module_name)
7272
self.module_fit(module, context)

tests/assets/configs/multiclass.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
- cross-encoder/ms-marco-MiniLM-L-6-v2
1919
- avsolatorio/GIST-small-Embedding-v0
2020
k: [1, 3]
21-
train_head: [false, true]
21+
# train_head: [false, true]
2222
- module_name: sklearn
2323
embedder_name:
2424
- sergeyzh/rubert-tiny-turbo
@@ -29,7 +29,7 @@
2929
k: [ 5, 10 ]
3030
weights: [uniform, distance, closest]
3131
m: [ 2, 3 ]
32-
cross_encoder_name:
32+
cross_encoder_config:
3333
- cross-encoder/ms-marco-MiniLM-L-6-v2
3434
- node_type: decision
3535
target_metric: decision_accuracy

tests/callback/test_callback.py

Lines changed: 15 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from copy import deepcopy
12
from typing import Any
23

34
import numpy as np
@@ -18,7 +19,7 @@ def start_run(self, **kwargs: dict[str, Any]) -> None:
1819
self.history.append(("start_run", kwargs))
1920

2021
def start_module(self, **kwargs: dict[str, Any]) -> None:
21-
self.history.append(("start_module", kwargs))
22+
self.history.append(("start_module", deepcopy(kwargs)))
2223

2324
def log_value(self, **kwargs: dict[str, Any]) -> None:
2425
self.history.append(("log_value", kwargs))
@@ -52,7 +53,7 @@ def test_pipeline_callbacks(dataset):
5253
{
5354
"module_name": "retrieval",
5455
"k": [5, 10],
55-
"embedder_name": ["sergeyzh/rubert-tiny-turbo"],
56+
"embedder_config": ["sergeyzh/rubert-tiny-turbo"],
5657
}
5758
],
5859
},
@@ -98,88 +99,39 @@ def test_pipeline_callbacks(dataset):
9899
(
99100
"start_module",
100101
{
102+
"module_kwargs": {"embedder_config": "sergeyzh/rubert-tiny-turbo", "k": 5},
101103
"module_name": "retrieval",
102104
"num": 0,
103-
"module_kwargs": {"k": 5, "embedder_name": "sergeyzh/rubert-tiny-turbo"},
104-
},
105-
),
106-
(
107-
"log_metric",
108-
{
109-
"metrics": {
110-
"retrieval_hit_rate": 1.0,
111-
}
112105
},
113106
),
107+
("log_metric", {"metrics": {"retrieval_hit_rate": 1.0}}),
114108
("end_module", {}),
115109
(
116110
"start_module",
117111
{
112+
"module_kwargs": {"embedder_config": "sergeyzh/rubert-tiny-turbo", "k": 10},
118113
"module_name": "retrieval",
119114
"num": 1,
120-
"module_kwargs": {"k": 10, "embedder_name": "sergeyzh/rubert-tiny-turbo"},
121-
},
122-
),
123-
(
124-
"log_metric",
125-
{
126-
"metrics": {
127-
"retrieval_hit_rate": 1.0,
128-
}
129115
},
130116
),
117+
("log_metric", {"metrics": {"retrieval_hit_rate": 1.0}}),
131118
("end_module", {}),
132119
(
133120
"start_module",
134-
{
135-
"module_name": "knn",
136-
"num": 0,
137-
"module_kwargs": {"k": 1, "weights": "uniform", "embedder_name": "sergeyzh/rubert-tiny-turbo"},
138-
},
139-
),
140-
(
141-
"log_metric",
142-
{
143-
"metrics": {
144-
"scoring_accuracy": 1.0,
145-
"scoring_roc_auc": 1.0,
146-
}
147-
},
121+
{"module_kwargs": {"embedder_config": None, "k": 1, "weights": "uniform"}, "module_name": "knn", "num": 0},
148122
),
123+
("log_metric", {"metrics": {"scoring_accuracy": 1.0, "scoring_roc_auc": 1.0}}),
149124
("end_module", {}),
150125
(
151126
"start_module",
152-
{
153-
"module_name": "knn",
154-
"num": 1,
155-
"module_kwargs": {"k": 1, "weights": "distance", "embedder_name": "sergeyzh/rubert-tiny-turbo"},
156-
},
157-
),
158-
(
159-
"log_metric",
160-
{
161-
"metrics": {
162-
"scoring_accuracy": 1.0,
163-
"scoring_roc_auc": 1.0,
164-
}
165-
},
127+
{"module_kwargs": {"embedder_config": None, "k": 1, "weights": "distance"}, "module_name": "knn", "num": 1},
166128
),
129+
("log_metric", {"metrics": {"scoring_accuracy": 1.0, "scoring_roc_auc": 1.0}}),
167130
("end_module", {}),
168-
(
169-
"start_module",
170-
{"module_name": "linear", "num": 0, "module_kwargs": {"embedder_name": "sergeyzh/rubert-tiny-turbo"}},
171-
),
172-
(
173-
"log_metric",
174-
{
175-
"metrics": {
176-
"scoring_accuracy": 0.75,
177-
"scoring_roc_auc": 1.0,
178-
}
179-
},
180-
),
131+
("start_module", {"module_kwargs": {"embedder_config": None}, "module_name": "linear", "num": 0}),
132+
("log_metric", {"metrics": {"scoring_accuracy": 0.75, "scoring_roc_auc": 1.0}}),
181133
("end_module", {}),
182-
("start_module", {"module_name": "threshold", "num": 0, "module_kwargs": {"thresh": 0.5}}),
134+
("start_module", {"module_kwargs": {"thresh": 0.5}, "module_name": "threshold", "num": 0}),
183135
(
184136
"log_metric",
185137
{
@@ -193,7 +145,7 @@ def test_pipeline_callbacks(dataset):
193145
},
194146
),
195147
("end_module", {}),
196-
("start_module", {"module_name": "argmax", "num": 0, "module_kwargs": {}}),
148+
("start_module", {"module_kwargs": {}, "module_name": "argmax", "num": 0}),
197149
(
198150
"log_metric",
199151
{

tests/configs/test_embedding.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,12 @@ def valid_embedding_config():
1212
"node_type": "embedding",
1313
"target_metric": "retrieval_mrr",
1414
"search_space": [
15-
{"module_name": "logreg_embedding", "embedder_name": ["sergeyzh/rubert-tiny-turbo"], "cv": [3, 5]},
16-
{"module_name": "retrieval", "embedder_name": ["sentence-transformers/all-MiniLM-L6-v2"], "k": [5, 10]},
15+
{"module_name": "logreg_embedding", "embedder_config": ["sergeyzh/rubert-tiny-turbo"], "cv": [3, 5]},
16+
{
17+
"module_name": "retrieval",
18+
"embedder_config": ["sentence-transformers/all-MiniLM-L6-v2"],
19+
"k": [5, 10],
20+
},
1721
],
1822
}
1923
]

tests/configs/test_scoring.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,40 +15,40 @@ def valid_scoring_config():
1515
{
1616
"module_name": "dnnc",
1717
"cross_encoder_name": ["cross-encoder/ms-marco-MiniLM-L-6-v2"],
18-
"embedder_name": ["sergeyzh/rubert-tiny-turbo"],
18+
"embedder_config": ["sergeyzh/rubert-tiny-turbo"],
1919
"k": [5, 10],
2020
"train_head": [False, True],
2121
},
2222
{
2323
"module_name": "knn",
24-
"embedder_name": ["sentence-transformers/all-MiniLM-L6-v2"],
24+
"embedder_config": ["sentence-transformers/all-MiniLM-L6-v2"],
2525
"k": [5, 10],
2626
"weights": ["uniform", "distance"],
2727
},
28-
{"module_name": "linear", "embedder_name": ["sergeyzh/rubert-tiny-turbo"], "cv": [3, 5]},
28+
{"module_name": "linear", "embedder_config": ["sergeyzh/rubert-tiny-turbo"], "cv": [3, 5]},
2929
{
3030
"module_name": "mlknn",
31-
"embedder_name": ["sergeyzh/rubert-tiny-turbo"],
31+
"embedder_config": ["sergeyzh/rubert-tiny-turbo"],
3232
"k": [5, 10],
3333
"s": [1.0, 0.5],
3434
"ignore_first_neighbours": [0, 1],
3535
},
3636
{
3737
"module_name": "description",
38-
"embedder_name": ["sentence-transformers/all-MiniLM-L6-v2"],
38+
"embedder_config": ["sentence-transformers/all-MiniLM-L6-v2"],
3939
"temperature": [0.5, 1.0],
4040
},
4141
{
4242
"module_name": "rerank",
43-
"cross_encoder_name": ["cross-encoder/ms-marco-MiniLM-L-6-v2"],
43+
"embedder_config": ["cross-encoder/ms-marco-MiniLM-L-6-v2"],
4444
"embedder_name": ["sergeyzh/rubert-tiny-turbo"],
4545
"k": [5],
4646
"weights": ["distance"],
4747
"rank_threshold_cutoff": [None, 3],
4848
},
4949
{
5050
"module_name": "sklearn",
51-
"embedder_name": ["sentence-transformers/all-MiniLM-L6-v2"],
51+
"embedder_config": ["sentence-transformers/all-MiniLM-L6-v2"],
5252
"clf_name": ["LogisticRegression"],
5353
"clf_args": [{"C": 1.0}, {"C": 0.5}],
5454
},

tests/modules/decision/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def multiclass_fit_data(dataset):
1212
knn_params = {
1313
"k": 3,
1414
"weights": "distance",
15-
"embedder_name": "sergeyzh/rubert-tiny-turbo",
15+
"embedder_config": "sergeyzh/rubert-tiny-turbo",
1616
}
1717
scorer = KNNScorer(**knn_params)
1818

@@ -29,7 +29,7 @@ def multilabel_fit_data(dataset):
2929
knn_params = {
3030
"k": 3,
3131
"weights": "distance",
32-
"embedder_name": "sergeyzh/rubert-tiny-turbo",
32+
"embedder_config": "sergeyzh/rubert-tiny-turbo",
3333
}
3434
scorer = KNNScorer(**knn_params)
3535

0 commit comments

Comments
 (0)