@@ -1,9 +1,6 @@
-import json
 import logging
-from pathlib import Path
 from typing import Any
 
-import joblib
 import numpy as np
 import numpy.typing as npt
 from sklearn.linear_model import LogisticRegression
@@ -12,11 +9,11 @@
 from typing_extensions import Self
 
 from autointent import Context, Embedder
-from autointent.custom_types import BaseMetadataDict, LabelType
+from autointent.custom_types import LabelType
 from autointent.modules.abc import ScoringModule
 
 logger = logging.getLogger(__name__)
-AVAILIABLE_CLASSIFIERS = {
+AVAILABLE_CLASSIFIERS = {
     name: class_
     for name, class_ in all_estimators(
         type_filter=[
@@ -30,51 +27,30 @@
 }
 
 
-class SklearnScorerDumpDict(BaseMetadataDict):
-    """
-    Metadata for dumping the state of a SklearnScorer.
-
-    :ivar multilabel: Whether the task is multilabel classification.
-    :ivar batch_size: Batch size used for embedding.
-    :ivar max_length: Maximum sequence length for embedding, or None if not specified.
-    """
-
-    multilabel: bool
-    batch_size: int
-    max_length: int | None
-
-
 class SklearnScorer(ScoringModule):
     """
     Scoring module for classification using sklearn classifiers that implement the predict_proba() method.
 
     This module uses embeddings generated by a transformer model to train
     the chosen sklearn classifier for intent classification.
 
-    :ivar classifier_file_name: Filename for saving the classifier to disk.
-    :ivar embedding_model_subdir: Directory for saving the embedding model.
-    :ivar precomputed_embeddings: Flag indicating if embeddings are precomputed.
-    :ivar db_dir: Path to the database directory.
     :ivar name: Name of the scorer, defaults to "sklearn".
     """
 
-    classifier_file_name: str = "classifier.joblib"
-    embedding_model_subdir: str = "embedding_model"
-    precomputed_embeddings: bool = False
-    db_dir: str
     name = "sklearn"
 
     def __init__(
         self,
         embedder_name: str,
         clf_name: str,
+        embedder_batch_size: int = 32,
+        embedder_max_length: int | None = None,
+        embedder_device: str = "cpu",
+        embedder_use_cache: bool = True,
         cv: int = 3,
         clf_args: dict[str, Any] | None = None,
         n_jobs: int = -1,
-        device: str = "cpu",
         seed: int = 0,
-        batch_size: int = 32,
-        max_length: int | None = None,
     ) -> None:
         """
         Initialize the SklearnScorer.
@@ -84,20 +60,22 @@ def __init__(
         :param cv: Number of cross-validation folds, defaults to 3.
         :param clf_args: Dictionary with the chosen sklearn classifier arguments, defaults to {}.
         :param n_jobs: Number of parallel jobs for cross-validation, defaults to -1 (all CPUs).
-        :param device: Device to run operations on, e.g., "cpu" or "cuda".
         :param seed: Random seed for reproducibility, defaults to 0.
-        :param batch_size: Batch size for embedding generation, defaults to 32.
-        :param max_length: Maximum sequence length for embedding, or None for default.
+        :param embedder_batch_size: Batch size for embedding generation, defaults to 32.
+        :param embedder_max_length: Maximum sequence length for embedding, or None for default.
+        :param embedder_device: Device to run operations on, e.g., "cpu" or "cuda".
+        :param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
         """
         self.cv = cv
         self.n_jobs = n_jobs
-        self.device = device
         self.seed = seed
         self.embedder_name = embedder_name
-        self.batch_size = batch_size
-        self.max_length = max_length
         self.clf_name = clf_name
         self.clf_args = clf_args or {}
+        self.embedder_batch_size = embedder_batch_size
+        self.embedder_max_length = embedder_max_length
+        self.embedder_device = embedder_device
+        self.embedder_use_cache = embedder_use_cache
 
     @classmethod
     def from_context(
@@ -121,10 +99,11 @@ def from_context(
 
         return cls(
             embedder_name=embedder_name,
-            device=context.get_device(),
             seed=context.seed,
-            batch_size=context.get_batch_size(),
-            max_length=context.get_max_length(),
+            embedder_device=context.get_device(),
+            embedder_batch_size=context.get_batch_size(),
+            embedder_max_length=context.get_max_length(),
+            embedder_use_cache=context.get_use_cache(),
             clf_name=clf_name,
             clf_args=clf_args,
         )
@@ -144,14 +123,15 @@ def fit(
         self._multilabel = isinstance(labels[0], list)
 
         embedder = Embedder(
-            device=self.device,
+            device=self.embedder_device,
             model_name_or_path=self.embedder_name,
-            batch_size=self.batch_size,
-            max_length=self.max_length,
+            batch_size=self.embedder_batch_size,
+            max_length=self.embedder_max_length,
+            use_cache=self.embedder_use_cache,
         )
         features = embedder.embed(utterances)
-        if AVAILIABLE_CLASSIFIERS.get(self.clf_name):
-            base_clf = AVAILIABLE_CLASSIFIERS[self.clf_name](**self.clf_args)
+        if AVAILABLE_CLASSIFIERS.get(self.clf_name):
+            base_clf = AVAILABLE_CLASSIFIERS[self.clf_name](**self.clf_args)
         else:
             msg = f"Class {self.clf_name} does not exist in sklearn or does not have predict_proba method"
             logger.error(msg)
@@ -180,54 +160,3 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
     def clear_cache(self) -> None:
         """Clear cached data in memory used by the embedder."""
         self._embedder.delete()
-
-    def dump(self, path: str) -> None:
-        """
-        Save the SklearnScorer's metadata, classifier, and embedder to disk.
-
-        :param path: Path to the directory where assets will be dumped.
-        """
-        self.metadata = SklearnScorerDumpDict(
-            multilabel=self._multilabel,
-            batch_size=self.batch_size,
-            max_length=self.max_length,
-        )
-
-        dump_dir = Path(path)
-
-        metadata_path = dump_dir / self.metadata_dict_name
-        with metadata_path.open("w") as file:
-            json.dump(self.metadata, file, indent=4)
-
-        # dump sklearn model
-        clf_path = dump_dir / self.classifier_file_name
-        joblib.dump(self._clf, clf_path)
-
-        # dump sentence transformer model
-        self._embedder.dump(dump_dir / self.embedding_model_subdir)
-
-    def load(self, path: str) -> None:
-        """
-        Load the SklearnScorer's metadata, classifier, and embedder from disk.
-
-        :param path: Path to the directory containing the dumped assets.
-        """
-        dump_dir = Path(path)
-
-        metadata_path = dump_dir / self.metadata_dict_name
-        with metadata_path.open() as file:
-            metadata: SklearnScorerDumpDict = json.load(file)
-        self._multilabel = metadata["multilabel"]
-
-        # load sklearn model
-        clf_path = dump_dir / self.classifier_file_name
-        self._clf = joblib.load(clf_path)
-
-        # load sentence transformer model
-        embedder_dir = dump_dir / self.embedding_model_subdir
-        self._embedder = Embedder(
-            device=self.device,
-            model_name_or_path=embedder_dir,
-            batch_size=metadata["batch_size"],
-            max_length=metadata["max_length"],
-        )
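
With this commit, all embedder settings move into explicit embedder_* constructor arguments, and the on-disk persistence pair (dump/load plus SklearnScorerDumpDict) is removed from the module. A minimal usage sketch of the new interface, assuming the import path, embedder checkpoint, and toy dataset below, none of which are part of this diff:

from autointent.modules.scoring import AVAILABLE_CLASSIFIERS, SklearnScorer  # import path assumed

# The registry is built via sklearn.utils.all_estimators, keeping only
# classifiers that expose predict_proba, so a plain name lookup works:
assert "LogisticRegression" in AVAILABLE_CLASSIFIERS

scorer = SklearnScorer(
    embedder_name="sentence-transformers/all-MiniLM-L6-v2",  # assumed checkpoint
    clf_name="LogisticRegression",
    clf_args={"max_iter": 1000},
    embedder_device="cpu",
    embedder_batch_size=32,
    embedder_use_cache=True,
)
scorer.fit(
    ["turn on the lights", "what is the weather"],  # toy utterances
    [0, 1],                                         # intent labels
)
probs = scorer.predict(["switch the lamp on"])  # array of per-class probabilities
scorer.clear_cache()  # free the embedder when done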