
Commit 57eafbb

fix caching (#225)
1 parent fe2b79e commit 57eafbb

File tree

1 file changed: +4 −1 lines


autointent/_embedder.py

Lines changed: 4 additions & 1 deletion
@@ -188,17 +188,20 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
         Returns:
             A numpy array of embeddings.
         """
+        prompt = self.config.get_prompt(task_type)
+
         if self.config.use_cache:
             hasher = Hasher()
             hasher.update(self)
             hasher.update(utterances)
+            if prompt:
+                hasher.update(prompt)
 
             embeddings_path = _get_embeddings_path(hasher.hexdigest())
             if embeddings_path.exists():
                 return np.load(embeddings_path)  # type: ignore[no-any-return]
 
         self._load_model()
-        prompt = self.config.get_prompt(task_type)
 
         logger.debug(
             "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s, prompt=%s",
