Skip to content

Commit 0a48e3c

Browse files
committed
add sklearn scorer
1 parent b8f7151 commit 0a48e3c

File tree

4 files changed

+183
-6
lines changed

4 files changed

+183
-6
lines changed

autointent/modules/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
)
1212
from .regexp import RegExp
1313
from .retrieval import RetrievalModule, VectorDBModule
14-
from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, ScoringModule
14+
from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, ScoringModule, SklearnScorer
1515

1616
T = TypeVar("T", bound=Module)
1717

@@ -25,11 +25,11 @@ def create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
2525
RETRIEVAL_MODULES_MULTILABEL = RETRIEVAL_MODULES_MULTICLASS
2626

2727
SCORING_MODULES_MULTICLASS: dict[str, type[ScoringModule]] = create_modules_dict(
28-
[DNNCScorer, KNNScorer, LinearScorer, DescriptionScorer]
28+
[DNNCScorer, KNNScorer, LinearScorer, DescriptionScorer, SklearnScorer]
2929
)
3030

3131
SCORING_MODULES_MULTILABEL: dict[str, type[ScoringModule]] = create_modules_dict(
32-
[MLKnnScorer, LinearScorer, DescriptionScorer]
32+
[MLKnnScorer, LinearScorer, DescriptionScorer, SklearnScorer]
3333
)
3434

3535
PREDICTION_MODULES_MULTICLASS: dict[str, type[Module]] = create_modules_dict(
@@ -42,8 +42,7 @@ def create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
4242

4343
__all__ = [
4444
"Module",
45-
"AdaptivePredictor"
46-
"ArgmaxPredictor",
45+
"AdaptivePredictor" "ArgmaxPredictor",
4746
"JinoosPredictor",
4847
"PredictionModule",
4948
"ThresholdPredictor",
@@ -57,6 +56,7 @@ def create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
5756
"MLKnnScorer",
5857
"DescriptionScorer",
5958
"ScoringModule",
59+
"SklearnScorer",
6060
"RETRIEVAL_MODULES_MULTICLASS",
6161
"RETRIEVAL_MODULES_MULTILABEL",
6262
"SCORING_MODULES_MULTICLASS",

autointent/modules/scoring/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,14 @@
44
from .knn import KNNScorer
55
from .linear import LinearScorer
66
from .mlknn import MLKnnScorer
7+
from .sklearn import SklearnScorer
78

8-
__all__ = ["ScoringModule", "DNNCScorer", "KNNScorer", "LinearScorer", "MLKnnScorer", "DescriptionScorer"]
9+
__all__ = [
10+
"ScoringModule",
11+
"DNNCScorer",
12+
"KNNScorer",
13+
"LinearScorer",
14+
"MLKnnScorer",
15+
"DescriptionScorer",
16+
"SklearnScorer",
17+
]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .scorer import SklearnScorer
2+
3+
__all__ = ["SklearnScorer"]
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import json
2+
from pathlib import Path
3+
from typing import Any
4+
5+
import joblib
6+
import numpy as np
7+
import numpy.typing as npt
8+
from sklearn.multioutput import MultiOutputClassifier
9+
from sklearn.utils import all_estimators
10+
from typing_extensions import Self
11+
12+
from autointent.context import Context
13+
from autointent.context.embedder import Embedder
14+
from autointent.context.vector_index_client import VectorIndexClient
15+
from autointent.custom_types import BaseMetadataDict, LabelType
16+
from autointent.modules.scoring.base import ScoringModule
17+
18+
AVAILIABLE_CLASSIFIERS = {name: class_ for name, class_ in all_estimators() if hasattr(class_, "predict_proba")}
19+
20+
21+
class SklearnScorerDumpDict(BaseMetadataDict):
22+
multilabel: bool
23+
batch_size: int
24+
max_length: int | None
25+
26+
27+
class SklearnScorer(ScoringModule):
28+
classifier_file_name: str = "classifier.joblib"
29+
embedding_model_subdir: str = "embedding_model"
30+
precomputed_embeddings: bool = False
31+
db_dir: str
32+
name = "sklearn"
33+
34+
def __init__(
35+
self,
36+
model_name: str,
37+
clf_name: str,
38+
cv: int = 3,
39+
clf_args: dict = {}, # noqa: B006
40+
n_jobs: int = -1,
41+
device: str = "cpu",
42+
seed: int = 0,
43+
batch_size: int = 32,
44+
max_length: int | None = None,
45+
) -> None:
46+
self.cv = cv
47+
self.n_jobs = n_jobs
48+
self.device = device
49+
self.seed = seed
50+
self.model_name = model_name
51+
self.batch_size = batch_size
52+
self.max_length = max_length
53+
self.clf_name = clf_name
54+
self.clf_args = clf_args
55+
56+
@classmethod
57+
def from_context(
58+
cls,
59+
context: Context,
60+
clf_name: str,
61+
clf_args: dict = {}, # noqa: B006
62+
model_name: str | None = None,
63+
) -> Self:
64+
if model_name is None:
65+
model_name = context.optimization_info.get_best_embedder()
66+
precomputed_embeddings = True
67+
else:
68+
precomputed_embeddings = context.vector_index_client.exists(model_name)
69+
context.device = context.get_device()
70+
context.embedder_batch_size = context.get_batch_size()
71+
context.embedder_max_length = context.get_max_length()
72+
context.db_dir = context.get_db_dir()
73+
instance = cls(
74+
model_name=model_name,
75+
device=context.device,
76+
seed=context.seed,
77+
batch_size=context.embedder_batch_size,
78+
max_length=context.embedder_max_length,
79+
clf_name=clf_name,
80+
clf_args=clf_args,
81+
)
82+
instance.precomputed_embeddings = precomputed_embeddings
83+
instance.db_dir = str(context.db_dir)
84+
return instance
85+
86+
def fit(
87+
self,
88+
utterances: list[str],
89+
labels: list[LabelType],
90+
) -> None:
91+
self._multilabel = isinstance(labels[0], list)
92+
93+
if self.precomputed_embeddings:
94+
# this happens only when LinearScorer is within Pipeline opimization after RetrievalNode optimization
95+
vector_index_client = VectorIndexClient(self.device, self.db_dir, self.batch_size, self.max_length)
96+
vector_index = vector_index_client.get_index(self.model_name)
97+
features = vector_index.get_all_embeddings()
98+
if len(features) != len(utterances):
99+
msg = "Vector index mismatches provided utterances"
100+
raise ValueError(msg)
101+
embedder = vector_index.embedder
102+
else:
103+
embedder = Embedder(
104+
device=self.device, model_name=self.model_name, batch_size=self.batch_size, max_length=self.max_length
105+
)
106+
features = embedder.embed(utterances)
107+
base_clf = AVAILIABLE_CLASSIFIERS.get(self.clf_name)(**self.clf_args)
108+
109+
clf = MultiOutputClassifier(base_clf) if self._multilabel else base_clf
110+
111+
clf.fit(features, labels)
112+
113+
self._clf = clf
114+
self._embedder = embedder
115+
116+
def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
117+
features = self._embedder.embed(utterances)
118+
probas = self._clf.predict_proba(features)
119+
if self._multilabel:
120+
probas = np.stack(probas, axis=1)[..., 1]
121+
return probas # type: ignore[no-any-return]
122+
123+
def clear_cache(self) -> None:
124+
self._embedder.delete()
125+
126+
def dump(self, path: str) -> None:
127+
self.metadata = SklearnScorerDumpDict(
128+
multilabel=self._multilabel,
129+
batch_size=self.batch_size,
130+
max_length=self.max_length,
131+
)
132+
133+
dump_dir = Path(path)
134+
135+
metadata_path = dump_dir / self.metadata_dict_name
136+
with metadata_path.open("w") as file:
137+
json.dump(self.metadata, file, indent=4)
138+
139+
# dump sklearn model
140+
clf_path = dump_dir / self.classifier_file_name
141+
joblib.dump(self._clf, clf_path)
142+
143+
# dump sentence transformer model
144+
self._embedder.dump(dump_dir / self.embedding_model_subdir)
145+
146+
def load(self, path: str) -> None:
147+
dump_dir = Path(path)
148+
149+
metadata_path = dump_dir / self.metadata_dict_name
150+
with metadata_path.open() as file:
151+
metadata: SklearnScorerDumpDict = json.load(file)
152+
self._multilabel = metadata["multilabel"]
153+
154+
# load sklearn model
155+
clf_path = dump_dir / self.classifier_file_name
156+
self._clf = joblib.load(clf_path)
157+
158+
# load sentence transformer model
159+
embedder_dir = dump_dir / self.embedding_model_subdir
160+
self._embedder = Embedder(
161+
device=self.device,
162+
model_name=embedder_dir,
163+
batch_size=metadata["batch_size"],
164+
max_length=metadata["max_length"],
165+
)

0 commit comments

Comments
 (0)