Skip to content

Commit f092999

Browse files
committed
resolve conflicts
2 parents c15914c + ad097e8 commit f092999

File tree

27 files changed

+143
-110
lines changed

27 files changed

+143
-110
lines changed

autointent/context/embedder.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import logging
3+
import shutil
34
from pathlib import Path
45
from typing import TypedDict
56

@@ -37,12 +38,19 @@ def __init__(
3738

3839
self.logger = logging.getLogger(__name__)
3940

40-
def delete(self) -> None:
41+
def clear_ram(self) -> None:
4142
self.logger.debug("deleting embedder %s", self.model_name)
4243
self.embedding_model.cpu()
4344
del self.embedding_model
4445

46+
def delete(self) -> None:
47+
self.clear_ram()
48+
shutil.rmtree(
49+
self.dump_dir, ignore_errors=True
50+
) # TODO: `ignore_errors=True` is workaround for PermissionError: [WinError 5] Access is denied
51+
4552
def dump(self, path: Path) -> None:
53+
self.dump_dir = path
4654
metadata = EmbedderDumpMetadata(
4755
batch_size=self.batch_size,
4856
max_length=self.max_length,
@@ -53,6 +61,7 @@ def dump(self, path: Path) -> None:
5361
json.dump(metadata, file, indent=4)
5462

5563
def load(self, path: Path | str) -> None:
64+
self.dump_dir = Path(path)
5665
path = Path(path)
5766
with (path / self.metadata_dict_name).open() as file:
5867
metadata: EmbedderDumpMetadata = json.load(file)
@@ -71,4 +80,9 @@ def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
7180
)
7281
if self.max_length is not None:
7382
self.embedding_model.max_seq_length = self.max_length
74-
return self.embedding_model.encode(utterances, convert_to_numpy=True, batch_size=self.batch_size) # type: ignore[return-value]
83+
return self.embedding_model.encode(
84+
utterances,
85+
convert_to_numpy=True,
86+
batch_size=self.batch_size,
87+
normalize_embeddings=True,
88+
) # type: ignore[return-value]

autointent/context/vector_index_client/cache.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,6 @@
33

44

55
def get_db_dir(db_dir: str | Path | None = None) -> Path:
6-
"""
7-
Get the directory path for chroma db file.
8-
Use default cache dir if not provided.
9-
Save path into user config in order to remove it from cache later.
10-
"""
11-
12-
root = Path(db_dir) if db_dir is not None else Path.cwd()
13-
db_dir = root / "vector_db" / str(uuid4()) if db_dir is None else Path(db_dir)
6+
db_dir = Path.cwd() / ("vector_db_" + str(uuid4())) if db_dir is None else Path(db_dir)
147
db_dir.mkdir(parents=True, exist_ok=True)
15-
168
return db_dir

autointent/context/vector_index_client/vector_index.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,18 @@ def is_empty(self) -> bool:
4747

4848
def delete(self) -> None:
4949
self.logger.debug("deleting vector index %s", self.model_name)
50-
if hasattr(self, "index"):
51-
self.index.reset()
50+
self.embedder.delete()
51+
self.clear_ram()
52+
(self.dump_dir / "index.faiss").unlink()
53+
(self.dump_dir / "texts.json").unlink()
54+
(self.dump_dir / "labels.json").unlink()
55+
56+
def clear_ram(self) -> None:
57+
self.logger.debug("clearing vector index %s from ram", self.model_name)
58+
self.index.reset()
5259
self.labels = []
5360
self.texts = []
5461

55-
self.embedder.delete()
56-
5762
def _search_by_text(self, texts: list[str], k: int) -> list[list[dict[str, Any]]]:
5863
query_embedding: npt.NDArray[np.float64] = self.embedder.embed(texts) # type: ignore[assignment]
5964
return self._search_by_embedding(query_embedding, k)
@@ -122,11 +127,8 @@ def dump(self, dir_path: Path) -> None:
122127
with (self.dump_dir / "labels.json").open("w") as file:
123128
json.dump(self.labels, file, indent=4, ensure_ascii=False)
124129

125-
def load(self, dir_path: Path | None = None) -> None:
126-
self.delete()
127-
128-
if dir_path is None:
129-
dir_path = self.dump_dir
130+
def load(self, dir_path: Path) -> None:
131+
self.dump_dir = Path(dir_path)
130132
self.index = faiss.read_index(str(dir_path / "index.faiss"))
131133
self.embedder = Embedder(model_name=dir_path / "embedding_model", device=self.device)
132134
with (dir_path / "texts.json").open() as file:

autointent/context/vector_index_client/vector_index_client.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import json
22
import logging
3-
import shutil
43
from pathlib import Path
54

65
from autointent.custom_types import LabelType
@@ -87,10 +86,10 @@ def _get_dump_dirpath(self, model_name: str) -> Path:
8786
return self.db_dir / dir_name
8887

8988
def delete_index(self, model_name: str) -> None:
90-
dir_name = self._remove_index_dirname(model_name)
91-
if dir_name is not None:
92-
self._logger.debug("Deleting index for model: %s", model_name)
93-
shutil.rmtree(self.db_dir / dir_name)
89+
if not self.exists(model_name):
90+
return
91+
index = self.get_index(model_name)
92+
index.delete()
9493

9594
def get_index(self, model_name: str) -> VectorIndex:
9695
dirpath = self._get_index_dirpath(model_name)
@@ -107,7 +106,14 @@ def exists(self, model_name: str) -> bool:
107106
return self._get_index_dirpath(model_name) is not None
108107

109108
def delete_db(self) -> None:
110-
shutil.rmtree(self.db_dir)
109+
path = self.db_dir / "indexes_dirnames.json"
110+
if not path.exists():
111+
return
112+
with path.open() as file:
113+
indexes_dirnames: DIRNAMES_TYPE = json.load(file)
114+
for embedder_name in indexes_dirnames:
115+
self.delete_index(embedder_name)
116+
path.unlink()
111117

112118

113119
class NonExistingIndexError(Exception):

autointent/metrics/scoring.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __call__(self, labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> floa
2424
...
2525

2626

27-
def scoring_log_likelihood(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
27+
def scoring_log_likelihood(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE, eps: float = 1e-10) -> float:
2828
"""
2929
supports multiclass and multilabel
3030
@@ -45,9 +45,10 @@ def scoring_log_likelihood(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE)
4545
where `s[i,c]` is a predicted score of `i`th utterance having ground truth label `c`
4646
"""
4747
labels_array, scores_array = transform(labels, scores)
48+
scores_array[scores_array == 0] = eps
4849

4950
if np.any((scores_array <= 0) | (scores_array > 1)):
50-
msg = "One or more scores are not from [0,1]. It is incompatible with `scoring_log_likelihood` metric"
51+
msg = "One or more scores are not from (0,1]. It is incompatible with `scoring_log_likelihood` metric"
5152
logger.error(msg)
5253
raise ValueError(msg)
5354

autointent/modules/retrieval/vectordb.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,9 @@ def __init__(
3232
batch_size: int = 32,
3333
max_length: int | None = None,
3434
) -> None:
35-
if db_dir is None:
36-
db_dir = str(get_db_dir())
3735
self.embedder_name = embedder_name
3836
self.device = device
39-
self.db_dir = db_dir
37+
self._db_dir = db_dir
4038
self.batch_size = batch_size
4139
self.max_length = max_length
4240

@@ -58,6 +56,12 @@ def from_context(
5856
max_length=context.get_max_length(),
5957
)
6058

59+
@property
60+
def db_dir(self) -> str:
61+
if self._db_dir is None:
62+
self._db_dir = str(get_db_dir())
63+
return self._db_dir
64+
6165
def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
6266
vector_index_client = VectorIndexClient(
6367
self.device, self.db_dir, embedder_batch_size=self.batch_size, embedder_max_length=self.max_length
@@ -76,7 +80,7 @@ def get_assets(self) -> RetrieverArtifact:
7680
return RetrieverArtifact(embedder_name=self.embedder_name)
7781

7882
def clear_cache(self) -> None:
79-
self.vector_index.delete()
83+
self.vector_index.clear_ram()
8084

8185
def dump(self, path: str) -> None:
8286
self.metadata = VectorDBMetadata(

autointent/modules/scoring/description/description.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from autointent.context import Context
1212
from autointent.context.embedder import Embedder
1313
from autointent.context.vector_index_client import VectorIndex, VectorIndexClient
14-
from autointent.context.vector_index_client.cache import get_db_dir
1514
from autointent.custom_types import LabelType
1615
from autointent.modules.scoring.base import ScoringModule
1716

@@ -30,22 +29,19 @@ class DescriptionScorer(ScoringModule):
3029
precomputed_embeddings: bool = False
3130
embedding_model_subdir: str = "embedding_model"
3231
_vector_index: VectorIndex
32+
db_dir: str
3333
name = "description"
3434

3535
def __init__(
3636
self,
3737
embedder_name: str,
38-
db_dir: Path | None = None,
3938
temperature: float = 1.0,
4039
device: str = "cpu",
4140
batch_size: int = 32,
4241
max_length: int | None = None,
4342
) -> None:
44-
if db_dir is None:
45-
db_dir = get_db_dir()
4643
self.temperature = temperature
4744
self.device = device
48-
self.db_dir = db_dir
4945
self.embedder_name = embedder_name
5046
self.batch_size = batch_size
5147
self.max_length = max_length
@@ -66,10 +62,10 @@ def from_context(
6662
instance = cls(
6763
temperature=temperature,
6864
device=context.get_device(),
69-
db_dir=context.get_db_dir(),
7065
embedder_name=embedder_name,
7166
)
7267
instance.precomputed_embeddings = precomputed_embeddings
68+
instance.db_dir = str(context.get_db_dir())
7369
return instance
7470

7571
def get_embedder_name(self) -> str:
@@ -127,7 +123,7 @@ def predict(self, utterances: list[str]) -> NDArray[np.float64]:
127123
return probabilites # type: ignore[no-any-return]
128124

129125
def clear_cache(self) -> None:
130-
self.embedder.delete()
126+
self.embedder.clear_ram()
131127

132128
def dump(self, path: str) -> None:
133129
self.metadata = DescriptionScorerDumpMetadata(

autointent/modules/scoring/dnnc/dnnc.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,21 @@ def __init__(
5252
batch_size: int = 32,
5353
max_length: int | None = None,
5454
) -> None:
55-
if db_dir is None:
56-
db_dir = str(get_db_dir())
57-
5855
self.cross_encoder_name = cross_encoder_name
5956
self.embedder_name = embedder_name
6057
self.k = k
6158
self.train_head = train_head
6259
self.device = device
63-
self.db_dir = db_dir
60+
self._db_dir = db_dir
6461
self.batch_size = batch_size
6562
self.max_length = max_length
6663

64+
@property
65+
def db_dir(self) -> str:
66+
if self._db_dir is None:
67+
self._db_dir = str(get_db_dir())
68+
return self._db_dir
69+
6770
@classmethod
6871
def from_context(
6972
cls,
@@ -175,7 +178,7 @@ def _build_result(self, scores: list[list[float]], labels: list[list[LabelType]]
175178
return build_result(np.array(scores), np.array(labels), n_classes)
176179

177180
def clear_cache(self) -> None:
178-
pass
181+
self.vector_index.clear_ram()
179182

180183
def dump(self, path: str) -> None:
181184
self.metadata = DNNCScorerDumpMetadata(

autointent/modules/scoring/dnnc/head_training.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch
2222
from sentence_transformers import CrossEncoder
2323
from sklearn.linear_model import LogisticRegressionCV
24+
from typing_extensions import Self
2425

2526
from autointent.custom_types import LabelType
2627

@@ -133,7 +134,7 @@ def set_classifier(self, clf: LogisticRegressionCV) -> None:
133134
self._clf = clf
134135

135136
@classmethod
136-
def load(cls, path: str) -> "CrossEncoderWithLogreg":
137+
def load(cls, path: str) -> Self:
137138
dump_dir = Path(path)
138139

139140
# load sklearn model
@@ -144,7 +145,7 @@ def load(cls, path: str) -> "CrossEncoderWithLogreg":
144145
crossencoder_dir = str(dump_dir / "crossencoder")
145146
model = CrossEncoder(crossencoder_dir) # TODO control device
146147

147-
res = CrossEncoderWithLogreg(model)
148+
res = cls(model)
148149
res.set_classifier(clf)
149150

150151
return res

autointent/modules/scoring/knn/knn.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,20 @@ def __init__(
4949
- closest: each sample has a non zero weight iff is the closest sample of some class
5050
- `device`: str, something like "cuda:0" or "cuda:0,1,2", a device to store embedding function
5151
"""
52-
if db_dir is None:
53-
db_dir = str(get_db_dir())
5452
self.embedder_name = embedder_name
5553
self.k = k
5654
self.weights = weights
57-
self.db_dir = db_dir
55+
self._db_dir = db_dir
5856
self.device = device
5957
self.batch_size = batch_size
6058
self.max_length = max_length
6159

60+
@property
61+
def db_dir(self) -> str:
62+
if self._db_dir is None:
63+
self._db_dir = str(get_db_dir())
64+
return self._db_dir
65+
6266
@classmethod
6367
def from_context(
6468
cls,
@@ -118,7 +122,7 @@ def predict_with_metadata(
118122
return scores, metadata
119123

120124
def clear_cache(self) -> None:
121-
self._vector_index.delete()
125+
self._vector_index.clear_ram()
122126

123127
def dump(self, path: str) -> None:
124128
self.metadata = KNNScorerDumpMetadata(

0 commit comments

Comments (0)