Skip to content

Commit e13f171

Browse files
committed
deleted kwargs and local savings, added config
1 parent c5b5b2c commit e13f171

File tree

2 files changed

+32
-26
lines changed

2 files changed

+32
-26
lines changed

autointent/_wrappers/embedder.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,22 @@
1010
from functools import lru_cache
1111
from pathlib import Path
1212
from typing import TypedDict
13+
import tempfile
1314

1415
import huggingface_hub
1516
import numpy as np
1617
import numpy.typing as npt
1718
import torch
1819
from appdirs import user_cache_dir
19-
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments, InputExample
20+
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
2021
from sentence_transformers.similarity_functions import SimilarityFunction
2122
from sentence_transformers.losses import BatchAllTripletLoss
2223
from sentence_transformers.training_args import BatchSamplers
2324
from datasets import Dataset
2425

2526

2627
from autointent._hash import Hasher
27-
from autointent.configs import EmbedderConfig, TaskTypeEnum
28+
from autointent.configs import EmbedderConfig, TaskTypeEnum, EmbedderFineTuningConfig
2829

2930
logger = logging.getLogger(__name__)
3031

@@ -126,7 +127,7 @@ def _load_model(self) -> None:
126127
similarity_fn_name=self.config.similarity_fn_name,
127128
trust_remote_code=self.config.trust_remote_code,
128129
)
129-
def train(self, utterances: list[str], labels: list[int], **kwargs) -> None:
130+
def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTuningConfig) -> None:
130131
"""Train the embedding model"""
131132
self._load_model()
132133

@@ -137,31 +138,29 @@ def train(self, utterances: list[str], labels: list[int], **kwargs) -> None:
137138

138139
loss = BatchAllTripletLoss(
139140
model=self.embedding_model,
140-
margin=kwargs.get("margin", 0.5)
141-
)
142-
143-
args = SentenceTransformerTrainingArguments(
144-
save_strategy="no",
145-
output_dir=kwargs['out_dir'],
146-
num_train_epochs=kwargs['epoch_num'],
147-
per_device_train_batch_size=self.config.batch_size,
148-
learning_rate=kwargs.get("learning_rate", 2e-5),
149-
warmup_ratio=kwargs.get("warmup_ratio", 0.1),
150-
fp16=kwargs.get("fp16", True),
151-
bf16=kwargs.get("bf16", False),
152-
batch_sampler=BatchSamplers.NO_DUPLICATES,
141+
margin=config.margin
153142
)
143+
with tempfile.TemporaryDirectory() as tmp_dir:
144+
args = SentenceTransformerTrainingArguments(
145+
save_strategy="no",
146+
output_dir=tmp_dir,
147+
num_train_epochs=config.epoch_num,
148+
per_device_train_batch_size=self.config.batch_size,
149+
learning_rate=config.learning_rate,
150+
warmup_ratio=config.warmup_ratio,
151+
fp16=config.fp16,
152+
bf16=config.bf16,
153+
batch_sampler=BatchSamplers.NO_DUPLICATES,
154+
)
154155

155-
trainer = SentenceTransformerTrainer(
156-
model=self.embedding_model,
157-
args=args,
158-
train_dataset=tr_ds,
159-
loss=loss,
160-
)
161-
162-
trainer.train()
163-
164-
self.embedding_model.save(kwargs['out_dir'])
156+
trainer = SentenceTransformerTrainer(
157+
model=self.embedding_model,
158+
args=args,
159+
train_dataset=tr_ds,
160+
loss=loss,
161+
)
162+
163+
trainer.train()
165164

166165
def clear_ram(self) -> None:
167166
"""Move the embedding model to CPU and delete it from memory."""

autointent/configs/_transformers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@ class TokenizerConfig(BaseModel):
1414
truncation: bool = True
1515
max_length: PositiveInt | None = Field(None, description="Maximum length of input sequences.")
1616

17+
class EmbedderFineTuningConfig(BaseModel):
18+
epoch_num: int
19+
margin: float = Field(default=0.5)
20+
learning_rate: float = Field(default=2e-5)
21+
warmup_ratio: float = Field(default=0.1)
22+
fp16: bool = Field(default=True)
23+
bf16: bool = Field(default=False)
1724

1825
class HFModelConfig(BaseModel):
1926
model_config = ConfigDict(extra="forbid")

0 commit comments

Comments
 (0)