
Commit ae27c04

feat: update logregembedding
1 parent 5843692 commit ae27c04

File tree

1 file changed (+28, -53 lines)


autointent/modules/embedding/_retrieval.py

Lines changed: 28 additions & 53 deletions
@@ -17,20 +17,9 @@
 from autointent.modules.abc import EmbeddingModule
 
 
-class RetrievalMetadata(BaseMetadataDict):
-    """Metadata class for RetrievalEmbedding."""
-
-    db_dir: str
-    batch_size: int
-    max_length: int | None
-
-
 class LogRegMetadata(BaseMetadataDict):
     """Metadata class for LogisticRegressionCV and LabelEncoder."""
 
-    db_dir: str
-    batch_size: int
-    max_length: int | None
     classes: list[str]
 
 
@@ -79,27 +68,26 @@ def __init__(
         k: int,
         embedder_name: str,
         cv: int = 3,
-        db_dir: str | None = None,
         embedder_device: str = "cpu",
-        batch_size: int = 32,
-        max_length: int | None = None,
-        embedder_use_cache: bool = False,
+        embedder_batch_size: int = 32,
+        embedder_max_length: int | None = None,
+        embedder_use_cache: bool = True,
     ) -> None:
         """
-        Initialize the RetrievalEmbedding.
+        Initialize the LogRegEmbedding.
 
+        :param cv:
+        :param k: Number of nearest neighbors to retrieve.
         :param embedder_name: Name of the embedder used for creating embeddings.
-        :param db_dir: Path to the database directory. If None, defaults will be used.
         :param embedder_device: Device to run operations on, e.g., "cpu" or "cuda".
         :param batch_size: Batch size for embedding generation.
         :param max_length: Maximum sequence length for embeddings. None if not set.
         :param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
         """
         self.embedder_name = embedder_name
         self.embedder_device = embedder_device
-        self._db_dir = db_dir
-        self.batch_size = batch_size
-        self.max_length = max_length
+        self.embedder_batch_size = embedder_batch_size
+        self.embedder_max_length = embedder_max_length
         self.embedder_use_cache = embedder_use_cache
         self.cv = cv
 
@@ -116,21 +104,25 @@ def from_context(
         """
         Create a LogRegEmbedding instance using a Context object.
 
+        :param cv:
         :param context: The context containing configurations and utilities.
+        :param k: Number of nearest neighbors to retrieve.
         :param embedder_name: Name of the embedder to use.
         :return: Initialized LogRegEmbedding instance.
         """
         return cls(
             k=k,
             cv=cv,
             embedder_name=embedder_name,
-            db_dir=str(context.get_db_dir()),
             embedder_device=context.get_device(),
-            batch_size=context.get_batch_size(),
-            max_length=context.get_max_length(),
+            embedder_batch_size=context.get_batch_size(),
+            embedder_max_length=context.get_max_length(),
             embedder_use_cache=context.get_use_cache(),
         )
 
+    def clear_cache(self) -> None:
+        """Clear cached data in memory used by the vector index."""
+
     def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
         """
         Train the logistic regression model using the provided utterances and labels.
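Taken together, the __init__ and from_context changes above rename the embedder keyword arguments (batch_size becomes embedder_batch_size, max_length becomes embedder_max_length), flip the embedder_use_cache default to True, and drop db_dir. A minimal construction sketch under the new signature; the import path and model name are illustrative assumptions, not part of this commit:

from autointent.modules.embedding._retrieval import LogRegEmbedding  # assumed import path

module = LogRegEmbedding(
    k=5,                              # number of nearest neighbors
    embedder_name="sentence-transformers/all-MiniLM-L6-v2",  # placeholder embedder name
    cv=3,                             # folds for LogisticRegressionCV
    embedder_device="cpu",
    embedder_batch_size=32,           # renamed from batch_size
    embedder_max_length=None,         # renamed from max_length
    embedder_use_cache=True,          # default changed from False to True
)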
@@ -140,23 +132,15 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
         """
         self._multilabel = isinstance(labels[0], list)
 
-        self._vector_index = VectorIndex(
-            self.embedder_name,
-            self.embedder_device,
-            self.embedder_batch_size,
-            self.embedder_max_length,
-            self.embedder_use_cache,
-        )
-        self._vector_index.add(utterances, labels)
-
         self.embedder = Embedder(
             device=self.embedder_device,
-            model_name=self.embedder_name,
-            batch_size=self.batch_size,
-            max_length=self.max_length,
+            model_name_or_path=self.embedder_name,
+            batch_size=self.embedder_batch_size,
+            max_length=self.embedder_max_length,
             use_cache=self.embedder_use_cache,
         )
         embeddings = self.embedder.embed(utterances)
+
         if self._multilabel:
             self.label_encoder = MultiLabelBinarizer()
             encoded_labels = self.label_encoder.fit_transform(labels)
@@ -209,42 +193,33 @@ def get_assets(self) -> RetrieverArtifact:
         """
         return RetrieverArtifact(embedder_name=self.embedder_name)
 
-    def clear_cache(self) -> None:
-        """Clear cached data in memory used by the vector index."""
-        self.vector_index.clear_ram()
-
-    def dump(self, path: str) -> None:
+    def dump(self, path: Path) -> None:
         """
         Save the module's metadata, classifier parameters, and label encoder to a specified directory.
 
         :param path: Path to the directory where assets will be dumped.
         """
-        self.metadata = LogRegMetadata(
-            batch_size=self.batch_size,
-            max_length=self.max_length,
-            db_dir=str(self.db_dir),
+        metadata = LogRegMetadata(
             classes=self.label_encoder.classes_.tolist(),
         )
 
-        self._vector_index.dump(Path(path))
+        path.mkdir(parents=True, exist_ok=True)
+        with (path / self.metadata_dict_name).open("w") as file:
+            json.dump(metadata, file, indent=4)
 
         classifier_path = "classifier.joblib"
-        joblib.dump(self.classifier, classifier_path)
+        joblib.dump(self.classifier, path / classifier_path)
 
-    def load(self, path: str) -> None:
+    def load(self, path: Path) -> None:
         """
         Load the module's metadata and model parameters from a specified directory.
 
         :param path: Path to the directory containing the dumped assets.
         """
-        dump_dir = Path(path)
-
-        with (dump_dir / self.metadata_dict_name).open() as file:
+        with (path / self.metadata_dict_name).open() as file:
             self.metadata: LogRegMetadata = json.load(file)
 
-        self._vector_index = VectorIndex.load(Path(path))
-
-        classifier_path = dump_dir / "classifier.joblib"
+        classifier_path = path / "classifier.joblib"
         self.classifier = joblib_load(classifier_path)
         self.label_encoder = LabelEncoder()
         self.label_encoder.classes_ = self.metadata["classes"]
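With dump() and load() now taking a pathlib.Path, writing the metadata as JSON, and storing the classifier with joblib inside that directory, a round-trip might look like the following sketch. The training data and directory name are illustrative, and it reuses the hypothetical `module` constructed in the earlier sketch:

from pathlib import Path

module.fit(
    ["turn on the lights", "play some jazz"],  # toy utterances
    [0, 1],                                    # toy labels
)

dump_dir = Path("runs/logreg_embedding")       # illustrative directory
module.dump(dump_dir)    # creates dump_dir, writes metadata JSON and classifier.joblib

restored = LogRegEmbedding(k=5, embedder_name="sentence-transformers/all-MiniLM-L6-v2")
restored.load(dump_dir)  # reads the metadata JSON, restores classifier and label encoder classes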
