Skip to content

Commit f6db40d

Browse files
committed
update
1 parent 0c3b3bc commit f6db40d

File tree

1 file changed

+24
-95
lines changed

1 file changed

+24
-95
lines changed
Lines changed: 24 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
import json
21
import logging
3-
from pathlib import Path
42
from typing import Any
53

6-
import joblib
74
import numpy as np
85
import numpy.typing as npt
96
from sklearn.linear_model import LogisticRegression
@@ -12,11 +9,11 @@
129
from typing_extensions import Self
1310

1411
from autointent import Context, Embedder
15-
from autointent.custom_types import BaseMetadataDict, LabelType
12+
from autointent.custom_types import LabelType
1613
from autointent.modules.abc import ScoringModule
1714

1815
logger = logging.getLogger(__name__)
19-
AVAILIABLE_CLASSIFIERS = {
16+
AVAILABLE_CLASSIFIERS = {
2017
name: class_
2118
for name, class_ in all_estimators(
2219
type_filter=[
@@ -30,51 +27,30 @@
3027
}
3128

3229

33-
class SklearnScorerDumpDict(BaseMetadataDict):
34-
"""
35-
Metadata for dumping the state of a SklearnScorer.
36-
37-
:ivar multilabel: Whether the task is multilabel classification.
38-
:ivar batch_size: Batch size used for embedding.
39-
:ivar max_length: Maximum sequence length for embedding, or None if not specified.
40-
"""
41-
42-
multilabel: bool
43-
batch_size: int
44-
max_length: int | None
45-
46-
4730
class SklearnScorer(ScoringModule):
4831
"""
4932
Scoring module for classification using sklearn classifiers with implemented predict_proba() method.
5033
5134
This module uses embeddings generated from a transformer model to train
5235
the chosen sklearn classifier for intent classification.
5336
54-
:ivar classifier_file_name: Filename for saving the classifier to disk.
55-
:ivar embedding_model_subdir: Directory for saving the embedding model.
56-
:ivar precomputed_embeddings: Flag indicating if embeddings are precomputed.
57-
:ivar db_dir: Path to the database directory.
5837
:ivar name: Name of the scorer, defaults to "sklearn".
5938
"""
6039

61-
classifier_file_name: str = "classifier.joblib"
62-
embedding_model_subdir: str = "embedding_model"
63-
precomputed_embeddings: bool = False
64-
db_dir: str
6540
name = "sklearn"
6641

6742
def __init__(
6843
self,
6944
embedder_name: str,
7045
clf_name: str,
46+
embedder_batch_size: int = 32,
47+
embedder_max_length: int | None = None,
48+
embedder_device: str = "cpu",
49+
embedder_use_cache: bool = True,
7150
cv: int = 3,
7251
clf_args: dict[str, Any] | None = None,
7352
n_jobs: int = -1,
74-
device: str = "cpu",
7553
seed: int = 0,
76-
batch_size: int = 32,
77-
max_length: int | None = None,
7854
) -> None:
7955
"""
8056
Initialize the SklearnScorer.
@@ -84,20 +60,22 @@ def __init__(
8460
:param cv: Number of cross-validation folds, defaults to 3.
8561
:param clf_args: Dictionary with the chosen sklearn classifier arguments, defaults to None (treated as an empty dict).
8662
:param n_jobs: Number of parallel jobs for cross-validation, defaults to -1 (all CPUs).
87-
:param device: Device to run operations on, e.g., "cpu" or "cuda".
8863
:param seed: Random seed for reproducibility, defaults to 0.
89-
:param batch_size: Batch size for embedding generation, defaults to 32.
90-
:param max_length: Maximum sequence length for embedding, or None for default.
64+
:param embedder_batch_size: Batch size for embedding generation, defaults to 32.
65+
:param embedder_max_length: Maximum sequence length for embedding, or None for default.
66+
:param embedder_device: Device to run operations on, e.g., "cpu" or "cuda".
67+
:param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
9168
"""
9269
self.cv = cv
9370
self.n_jobs = n_jobs
94-
self.device = device
9571
self.seed = seed
9672
self.embedder_name = embedder_name
97-
self.batch_size = batch_size
98-
self.max_length = max_length
9973
self.clf_name = clf_name
10074
self.clf_args = clf_args or {}
75+
self.embedder_batch_size = embedder_batch_size
76+
self.embedder_max_length = embedder_max_length
77+
self.embedder_device = embedder_device
78+
self.embedder_use_cache = embedder_use_cache
10179

10280
@classmethod
10381
def from_context(
@@ -121,10 +99,11 @@ def from_context(
12199

122100
return cls(
123101
embedder_name=embedder_name,
124-
device=context.get_device(),
125102
seed=context.seed,
126-
batch_size=context.get_batch_size(),
127-
max_length=context.get_max_length(),
103+
embedder_device=context.get_device(),
104+
embedder_batch_size=context.get_batch_size(),
105+
embedder_max_length=context.get_max_length(),
106+
embedder_use_cache=context.get_use_cache(),
128107
clf_name=clf_name,
129108
clf_args=clf_args,
130109
)
@@ -144,14 +123,15 @@ def fit(
144123
self._multilabel = isinstance(labels[0], list)
145124

146125
embedder = Embedder(
147-
device=self.device,
126+
device=self.embedder_device,
148127
model_name_or_path=self.embedder_name,
149-
batch_size=self.batch_size,
150-
max_length=self.max_length,
128+
batch_size=self.embedder_batch_size,
129+
max_length=self.embedder_max_length,
130+
use_cache=self.embedder_use_cache,
151131
)
152132
features = embedder.embed(utterances)
153-
if AVAILIABLE_CLASSIFIERS.get(self.clf_name):
154-
base_clf = AVAILIABLE_CLASSIFIERS[self.clf_name](**self.clf_args)
133+
if AVAILABLE_CLASSIFIERS.get(self.clf_name):
134+
base_clf = AVAILABLE_CLASSIFIERS[self.clf_name](**self.clf_args)
155135
else:
156136
msg = f"Class {self.clf_name} does not exist in sklearn or does not have predict_proba method"
157137
logger.error(msg)
@@ -180,54 +160,3 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
180160
def clear_cache(self) -> None:
181161
"""Clear cached data in memory used by the embedder."""
182162
self._embedder.delete()
183-
184-
def dump(self, path: str) -> None:
185-
"""
186-
Save the SklearnScorer's metadata, classifier, and embedder to disk.
187-
188-
:param path: Path to the directory where assets will be dumped.
189-
"""
190-
self.metadata = SklearnScorerDumpDict(
191-
multilabel=self._multilabel,
192-
batch_size=self.batch_size,
193-
max_length=self.max_length,
194-
)
195-
196-
dump_dir = Path(path)
197-
198-
metadata_path = dump_dir / self.metadata_dict_name
199-
with metadata_path.open("w") as file:
200-
json.dump(self.metadata, file, indent=4)
201-
202-
# dump sklearn model
203-
clf_path = dump_dir / self.classifier_file_name
204-
joblib.dump(self._clf, clf_path)
205-
206-
# dump sentence transformer model
207-
self._embedder.dump(dump_dir / self.embedding_model_subdir)
208-
209-
def load(self, path: str) -> None:
210-
"""
211-
Load the SklearnScorer's metadata, classifier, and embedder from disk.
212-
213-
:param path: Path to the directory containing the dumped assets.
214-
"""
215-
dump_dir = Path(path)
216-
217-
metadata_path = dump_dir / self.metadata_dict_name
218-
with metadata_path.open() as file:
219-
metadata: SklearnScorerDumpDict = json.load(file)
220-
self._multilabel = metadata["multilabel"]
221-
222-
# load sklearn model
223-
clf_path = dump_dir / self.classifier_file_name
224-
self._clf = joblib.load(clf_path)
225-
226-
# load sentence transformer model
227-
embedder_dir = dump_dir / self.embedding_model_subdir
228-
self._embedder = Embedder(
229-
device=self.device,
230-
model_name_or_path=embedder_dir,
231-
batch_size=metadata["batch_size"],
232-
max_length=metadata["max_length"],
233-
)

0 commit comments

Comments
 (0)