 from sklearn.metrics.pairwise import cosine_similarity
 
 from autointent import Context, Embedder
-from autointent.context.vector_index_client import VectorIndex, VectorIndexClient
 from autointent.custom_types import LabelType
 from autointent.modules.abc import ScoringModule
 
 
 class DescriptionScorerDumpMetadata(TypedDict):
     """Metadata for dumping the state of a DescriptionScorer."""
 
-    db_dir: str
     n_classes: int
     multilabel: bool
-    batch_size: int
-    max_length: int | None
+    embedder_batch_size: int
+    embedder_max_length: int | None
 
 
 class DescriptionScorer(ScoringModule):
@@ -34,46 +32,40 @@ class DescriptionScorer(ScoringModule):
 
     :ivar weights_file_name: Filename for saving the description vectors (`description_vectors.npy`).
     :ivar embedder: The embedder used to generate embeddings for utterances and descriptions.
-    :ivar precomputed_embeddings: Flag indicating whether precomputed embeddings are used.
     :ivar embedding_model_subdir: Directory for storing the embedder's model files.
-    :ivar _vector_index: Internal vector index used when embeddings are precomputed.
-    :ivar db_dir: Directory path where the vector database is stored.
     :ivar name: Name of the scorer, defaults to "description".
 
     """
 
     weights_file_name: str = "description_vectors.npy"
     embedder: Embedder
-    precomputed_embeddings: bool = False
     embedding_model_subdir: str = "embedding_model"
-    _vector_index: VectorIndex
-    db_dir: str
     name = "description"
 
     def __init__(
         self,
         embedder_name: str,
         temperature: float = 1.0,
         embedder_device: str = "cpu",
-        batch_size: int = 32,
-        max_length: int | None = None,
-        embedder_use_cache: bool = False,
+        embedder_batch_size: int = 32,
+        embedder_max_length: int | None = None,
+        embedder_use_cache: bool = True,
     ) -> None:
         """
         Initialize the DescriptionScorer.
 
         :param embedder_name: Name of the embedder model.
         :param temperature: Temperature parameter for scaling logits, defaults to 1.0.
         :param embedder_device: Device to run the embedder on, e.g., "cpu" or "cuda".
-        :param batch_size: Batch size for embedding generation, defaults to 32.
-        :param max_length: Maximum sequence length for embedding, defaults to None.
+        :param embedder_batch_size: Batch size for embedding generation, defaults to 32.
+        :param embedder_max_length: Maximum sequence length for embedding, defaults to None.
         :param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
         """
         self.temperature = temperature
         self.embedder_device = embedder_device
         self.embedder_name = embedder_name
-        self.batch_size = batch_size
-        self.max_length = max_length
+        self.embedder_batch_size = embedder_batch_size
+        self.embedder_max_length = embedder_max_length
         self.embedder_use_cache = embedder_use_cache
 
     @classmethod
@@ -93,19 +85,15 @@ def from_context(
         """
         if embedder_name is None:
             embedder_name = context.optimization_info.get_best_embedder()
-            precomputed_embeddings = True
-        else:
-            precomputed_embeddings = context.vector_index_client.exists(embedder_name)
 
-        instance = cls(
+        return cls(
             temperature=temperature,
             embedder_device=context.get_device(),
             embedder_name=embedder_name,
             embedder_use_cache=context.get_use_cache(),
+            embedder_batch_size=context.get_batch_size(),
+            embedder_max_length=context.get_max_length(),
         )
-        instance.precomputed_embeddings = precomputed_embeddings
-        instance.db_dir = str(context.get_db_dir())
-        return instance
 
     def get_embedder_name(self) -> str:
         """
@@ -136,39 +124,22 @@ def fit(
         self.n_classes = len(set(labels))
         self.multilabel = False
 
-        if self.precomputed_embeddings:
-            # this happens only when LinearScorer is within Pipeline opimization after RetrievalNode optimization
-            vector_index_client = VectorIndexClient(
-                self.embedder_device,
-                self.db_dir,
-                self.batch_size,
-                self.max_length,
-                self.embedder_use_cache,
-            )
-            vector_index = vector_index_client.get_index(self.embedder_name)
-            features = vector_index.get_all_embeddings()
-            if len(features) != len(utterances):
-                msg = "Vector index mismatches provided utterances"
-                raise ValueError(msg)
-            embedder = vector_index.embedder
-        else:
-            embedder = Embedder(
-                device=self.embedder_device,
-                model_name=self.embedder_name,
-                batch_size=self.batch_size,
-                max_length=self.max_length,
-                use_cache=self.embedder_use_cache,
-            )
-            features = embedder.embed(utterances)
-
         if any(description is None for description in descriptions):
             error_text = (
                 "Some intent descriptions (label_description) are missing (None). "
                 "Please ensure all intents have descriptions."
             )
             raise ValueError(error_text)
 
-        self.description_vectors = embedder.embed([desc for desc in descriptions if desc])
+        embedder = Embedder(
+            device=self.embedder_device,
+            model_name=self.embedder_name,
+            batch_size=self.embedder_batch_size,
+            max_length=self.embedder_max_length,
+            use_cache=self.embedder_use_cache,
+        )
+
+        self.description_vectors = embedder.embed(descriptions)
         self.embedder = embedder
 
     def predict(self, utterances: list[str]) -> NDArray[np.float64]:
@@ -198,11 +169,10 @@ def dump(self, path: str) -> None:
         :param path: Path to the directory where assets will be dumped.
         """
         self.metadata = DescriptionScorerDumpMetadata(
-            db_dir=str(self.db_dir),
             n_classes=self.n_classes,
             multilabel=self.multilabel,
-            batch_size=self.batch_size,
-            max_length=self.max_length,
+            embedder_batch_size=self.embedder_batch_size,
+            embedder_max_length=self.embedder_max_length,
         )
 
         dump_dir = Path(path)
@@ -232,7 +202,7 @@ def load(self, path: str) -> None:
         self.embedder = Embedder(
             device=self.embedder_device,
             model_name=embedder_dir,
-            batch_size=self.metadata["batch_size"],
-            max_length=self.metadata["max_length"],
+            batch_size=self.metadata["embedder_batch_size"],
+            max_length=self.metadata["embedder_max_length"],
             use_cache=self.embedder_use_cache,
         )
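
Below is a minimal usage sketch of the API after this change. The import path, the embedder model name, and the exact keyword signature of `fit()` are assumptions for illustration; the constructor parameters (`embedder_batch_size`, `embedder_max_length`, `embedder_use_cache`), `predict()`, and the `dump()`/`load()` methods come from the diff above.

```python
# Hypothetical usage sketch; the import path and model name are assumed.
from autointent.modules.scoring import DescriptionScorer

scorer = DescriptionScorer(
    embedder_name="sentence-transformers/all-MiniLM-L6-v2",  # assumed model
    temperature=0.5,
    embedder_device="cpu",
    embedder_batch_size=16,    # renamed from batch_size
    embedder_max_length=128,   # renamed from max_length
    embedder_use_cache=True,   # default flipped to True in this change
)

# fit() now always builds a fresh Embedder and embeds the intent
# descriptions directly; the precomputed vector-index path is gone,
# so no db_dir is involved. Argument names here are assumed.
scorer.fit(
    utterances=["reset my password", "what is my balance"],
    labels=[0, 1],
    descriptions=["Requests to reset a password", "Account balance inquiries"],
)

probs = scorer.predict(["forgot my login"])  # per-class scores

# Round-trip: the dump metadata now carries the renamed
# embedder_batch_size / embedder_max_length keys and no longer stores db_dir.
scorer.dump("scorer_assets")
scorer.load("scorer_assets")
```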