Commit 19a74f4

embedder_model -> _model
1 parent: 0b0c1fa

File tree: 1 file changed

src/autointent/_wrappers/embedder.py

Lines changed: 21 additions & 14 deletions
@@ -77,6 +77,7 @@ class Embedder:
     _weights_dir_name: str = "sentence_transformer"
     _dump_dir: Path | None = None
     _trained: bool = False
+    _model: SentenceTransformer

     def __init__(self, embedder_config: EmbedderConfig) -> None:
         """Initialize the Embedder.
@@ -97,15 +98,15 @@ def _get_hash(self) -> int:
             commit_hash = _get_latest_commit_hash(self.config.model_name)
             hasher.update(commit_hash)
         else:
-            self.embedding_model = self._load_model()
-            for parameter in self.embedding_model.parameters():
+            self._model = self._load_model()
+            for parameter in self._model.parameters():
                 hasher.update(parameter.detach().cpu().numpy())
         hasher.update(self.config.tokenizer_config.max_length)
         return hasher.intdigest()

     def _load_model(self) -> SentenceTransformer:
         """Load sentence transformers model to device."""
-        if not hasattr(self, "embedding_model"):
+        if not hasattr(self, "_model"):
             res = SentenceTransformer(
                 self.config.model_name,
                 device=self.config.device,
@@ -114,7 +115,7 @@ def _load_model(self) -> SentenceTransformer:
                 trust_remote_code=self.config.trust_remote_code,
             )
         else:
-            res = self.embedding_model
+            res = self._model
         return res

     def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFineTuningConfig) -> None:
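The two `_load_model` hunks above preserve the lazy-construction idea: build the `SentenceTransformer` once, then hand back the cached instance on later calls. A minimal standalone sketch of that pattern, assuming nothing about autointent beyond what this diff shows (the class and constructor below are illustrative, not the library's):

import os

from sentence_transformers import SentenceTransformer


class LazyEmbedder:
    """Illustrative stand-in for the Embedder wrapper in this diff."""

    _model: SentenceTransformer  # annotation only; not assigned until first load

    def __init__(self, model_name: str, device: str = "cpu") -> None:
        self.model_name = model_name
        self.device = device

    def _load_model(self) -> SentenceTransformer:
        # hasattr() is the guard: a class-level annotation does not create
        # the attribute, so construction happens only on the first call.
        if not hasattr(self, "_model"):
            res = SentenceTransformer(self.model_name, device=self.device)
        else:
            res = self._model
        return res


embedder = LazyEmbedder("sentence-transformers/all-MiniLM-L6-v2")
# Callers store the result back into the attribute, exactly as the diff does:
embedder._model = embedder._load_model()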
@@ -133,7 +134,8 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFineTuningConfig) -> None:
             logger.warning(msg)
             return

-        self._load_model()
+        self._model = self._load_model()
+
         if config.early_stopping:
             x_train, x_val, y_train, y_val = train_test_split(utterances, labels, test_size=0.1, random_state=42)
             tr_ds = Dataset.from_dict({"text": x_train, "label": y_train})
@@ -142,7 +144,7 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFineTuningConfig) -> None:
             tr_ds = Dataset.from_dict({"text": utterances, "label": labels})
             val_ds = None

-        loss = BatchAllTripletLoss(model=self.embedding_model, margin=config.margin)
+        loss = BatchAllTripletLoss(model=self._model, margin=config.margin)
         with tempfile.TemporaryDirectory() as tmp_dir:
             args = SentenceTransformerTrainingArguments(
                 save_strategy="epoch",
@@ -169,7 +171,7 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFineTuningConfig) -> None:
                 )
             )
             trainer = SentenceTransformerTrainer(
-                model=self.embedding_model,
+                model=self._model,
                 args=args,
                 train_dataset=tr_ds,
                 eval_dataset=val_ds,
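The training hunks slot the renamed attribute into the standard sentence-transformers v3 fine-tuning flow: a labeled `Dataset`, a `BatchAllTripletLoss` over it, and a `SentenceTransformerTrainer`. A condensed, runnable sketch of that flow; the model name and four-utterance dataset are placeholders, and all training arguments except the output directory are left at their defaults:

import tempfile

from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import BatchAllTripletLoss

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# One text column plus a "label" column, matching the diff's Dataset.from_dict
# calls; BatchAllTripletLoss mines all valid triplets within each batch from labels.
train_ds = Dataset.from_dict(
    {"text": ["hi there", "hello", "bye now", "see you"], "label": [0, 0, 1, 1]}
)
loss = BatchAllTripletLoss(model=model, margin=5.0)

with tempfile.TemporaryDirectory() as tmp_dir:
    args = SentenceTransformerTrainingArguments(output_dir=tmp_dir, num_train_epochs=1)
    trainer = SentenceTransformerTrainer(
        model=model, args=args, train_dataset=train_ds, loss=loss
    )
    trainer.train()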
@@ -181,7 +183,7 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFineTuningConfig) -> None:

         # use temporary path for re-usage
         model_path = str(Path(tempfile.mkdtemp("autointent_embedders")) / str(uuid4()))
-        self.embedding_model.save(model_path)
+        self._model.save(model_path)
         self.config.model_name = model_path

         self._trained = True
@@ -190,8 +192,8 @@ def clear_ram(self) -> None:
         """Move the embedding model to CPU and delete it from memory."""
         if hasattr(self, "embedding_model"):
             logger.debug("Clearing embedder %s from memory", self.config.model_name)
-            self.embedding_model.cpu()
-            del self.embedding_model
+            self._model.cpu()
+            del self._model
             torch.cuda.empty_cache()

     def delete(self) -> None:
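`clear_ram` follows a common PyTorch memory-release recipe: move the weights to CPU, drop the Python reference, then flush the CUDA caching allocator. A hedged sketch of that recipe as a free function; the helper and its names are ours, not autointent's:

import torch


def release_model(holder: object, attr: str = "_model") -> None:
    """Illustrative mirror of clear_ram(): free GPU memory held via `attr`."""
    model = getattr(holder, attr, None)
    if model is not None:
        model.cpu()            # move parameters off the GPU first
        delattr(holder, attr)  # drop the reference so it can be collected
    torch.cuda.empty_cache()   # return cached CUDA blocks to the driver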
@@ -208,7 +210,7 @@ def dump(self, path: Path) -> None:
         """
         if self._trained:
             model_path = str((path / self._weights_dir_name).resolve())
-            self.embedding_model.save(model_path, create_model_card=False)
+            self._model.save(model_path, create_model_card=False)
             self.config.model_name = model_path

         self._dump_dir = path
@@ -248,6 +250,11 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
         Returns:
             A numpy array of embeddings.
         """
+        if len(utterances) == 0:
+            msg = "Empty input"
+            logger.error(msg)
+            raise ValueError(msg)
+
         prompt = self.config.get_prompt(task_type)

         if self.config.use_cache:
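The only behavioral addition in this commit is the guard above: `embed` now fails fast on an empty batch instead of returning a zero-row array. A hedged usage example; the `EmbedderConfig` import path and constructor are not shown in this diff, so both are assumptions here:

from autointent._wrappers.embedder import Embedder  # module path from this diff
from autointent import EmbedderConfig               # assumed import path

# Assumption: EmbedderConfig is constructible from just a model name; only
# the attribute self.config.model_name is actually visible in the diff.
embedder = Embedder(EmbedderConfig(model_name="sentence-transformers/all-MiniLM-L6-v2"))

try:
    embedder.embed([])
except ValueError as err:
    print(err)  # -> Empty input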
@@ -263,7 +270,7 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
             logger.debug("loading embeddings from %s", str(embeddings_path))
             return np.load(embeddings_path)  # type: ignore[no-any-return]

-        self.embedding_model = self._load_model()
+        self._model = self._load_model()

         logger.debug(
             "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s, prompt=%s",
@@ -275,9 +282,9 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
         )

         if self.config.tokenizer_config.max_length is not None:
-            self.embedding_model.max_seq_length = self.config.tokenizer_config.max_length
+            self._model.max_seq_length = self.config.tokenizer_config.max_length

-        embeddings = self.embedding_model.encode(
+        embeddings = self._model.encode(
             utterances,
             convert_to_numpy=True,
             batch_size=self.config.batch_size,
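Downstream of all the renames, `_model` ends in a plain `SentenceTransformer.encode` call. For reference, a minimal sketch of that final step with placeholder texts and settings:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # placeholder model
model.max_seq_length = 128  # mirrors the optional max_length override above

embeddings = model.encode(
    ["turn on the lights", "what is the weather"],
    convert_to_numpy=True,
    batch_size=32,
)
print(embeddings.shape)  # (2, 384) for this particular model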

0 commit comments
