Skip to content

Commit 0c3b3bc

Browse files
committed
add test
1 parent cf325f1 commit 0c3b3bc

File tree

4 files changed

+82
-8
lines changed

4 files changed

+82
-8
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .scorer import SklearnScorer
1+
from .sklearn_scorer import SklearnScorer
22

33
__all__ = ["SklearnScorer"]

autointent/modules/scoring/_sklearn/scorer.py renamed to autointent/modules/scoring/_sklearn/sklearn_scorer.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import logging
23
from pathlib import Path
34
from typing import Any
45

@@ -14,7 +15,19 @@
1415
from autointent.custom_types import BaseMetadataDict, LabelType
1516
from autointent.modules.abc import ScoringModule
1617

17-
AVAILIABLE_CLASSIFIERS = {name: class_ for name, class_ in all_estimators() if hasattr(class_, "predict_proba")}
18+
logger = logging.getLogger(__name__)
19+
AVAILIABLE_CLASSIFIERS = {
20+
name: class_
21+
for name, class_ in all_estimators(
22+
type_filter=[
23+
# remove transformer (e.g. TfidfTransformer) from the list of available classifiers
24+
"classifier",
25+
"regressor",
26+
"cluster",
27+
]
28+
)
29+
if hasattr(class_, "predict_proba")
30+
}
1831

1932

2033
class SklearnScorerDumpDict(BaseMetadataDict):
@@ -64,7 +77,7 @@ def __init__(
6477
max_length: int | None = None,
6578
) -> None:
6679
"""
67-
Initialize the LinearScorer.
80+
Initialize the SklearnScorer.
6881
6982
:param embedder_name: Name of the embedder model.
7083
:param clf_name: Name of the sklearn classifier to use.
@@ -84,7 +97,7 @@ def __init__(
8497
self.batch_size = batch_size
8598
self.max_length = max_length
8699
self.clf_name = clf_name
87-
self.clf_args = clf_args
100+
self.clf_args = clf_args or {}
88101

89102
@classmethod
90103
def from_context(
@@ -122,7 +135,7 @@ def fit(
122135
labels: list[LabelType],
123136
) -> None:
124137
"""
125-
Train the chosen skearn classifier.
138+
Train the chosen sklearn classifier.
126139
127140
:param utterances: List of training utterances.
128141
:param labels: List of labels corresponding to the utterances.
@@ -137,11 +150,11 @@ def fit(
137150
max_length=self.max_length,
138151
)
139152
features = embedder.embed(utterances)
140-
self.clf_args = {} if self.clf_args is None else self.clf_args
141153
if AVAILIABLE_CLASSIFIERS.get(self.clf_name):
142154
base_clf = AVAILIABLE_CLASSIFIERS[self.clf_name](**self.clf_args)
143155
else:
144156
msg = f"Class {self.clf_name} does not exist in sklearn or does not have predict_proba method"
157+
logger.error(msg)
145158
raise ValueError(msg)
146159

147160
clf = MultiOutputClassifier(base_clf) if self._multilabel else base_clf
@@ -170,7 +183,7 @@ def clear_cache(self) -> None:
170183

171184
def dump(self, path: str) -> None:
172185
"""
173-
Save the LinearScorer's metadata, classifier, and embedder to disk.
186+
Save the SklearnScorer's metadata, classifier, and embedder to disk.
174187
175188
:param path: Path to the directory where assets will be dumped.
176189
"""

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from autointent import Dataset
88

99

10-
def setup_environment() -> tuple[Path, Path, Path]:
10+
def setup_environment() -> tuple[Path, Path]:
1111
logs_dir = ires.files("tests").joinpath("logs") / str(uuid4())
1212
dump_dir = logs_dir / "modules_dump"
1313
return dump_dir, logs_dir
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
import numpy as np

from autointent.context.data_handler import DataHandler
from autointent.modules import SklearnScorer
from tests.conftest import setup_environment


def test_base_linear(dataset):
    """Fit a SklearnScorer (LogisticRegression over embeddings) and check its scores.

    Uses the project-wide ``dataset`` fixture; asserts the predicted class
    probabilities on five held-out utterances to 2 decimal places, then checks
    the metadata-returning prediction path.
    """
    dump_dir, logs_dir = setup_environment()

    data_handler = DataHandler(dataset)

    scorer = SklearnScorer(embedder_name="sergeyzh/rubert-tiny-turbo", clf_name="LogisticRegression")

    scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))

    test_data = [
        "why is there a hold on my american saving bank account",
        "i am nost sure why my account is blocked",
        "why is there a hold on my capital one checking account",
        "i think my account is blocked but i do not know the reason",
        "can you tell me why is my bank account frozen",
    ]
    predictions = scorer.predict(test_data)

    # Reference probabilities per utterance (one row each, four classes).
    # decimal=2 keeps the test tolerant to small numerical drift in the
    # embedder/classifier stack.
    expected = np.array(
        [
            [0.23748632, 0.39067508, 0.2393372, 0.13250139],
            [0.23913757, 0.37610976, 0.24952359, 0.13522908],
            [0.25714506, 0.34984371, 0.25495681, 0.13805442],
            [0.2571957, 0.34850898, 0.25346288, 0.14083245],
            [0.23885061, 0.41527567, 0.21830964, 0.12756408],
        ],
    )
    np.testing.assert_almost_equal(expected, predictions, decimal=2)

    # The metadata-bearing variant must score every input and report no metadata.
    predictions, metadata = scorer.predict_with_metadata(test_data)
    assert len(predictions) == len(test_data)
    assert metadata is None

0 commit comments

Comments
 (0)