Skip to content

Commit bdc1161

Browse files
committed
move batch_size to EmbedderFineTuningConfig
1 parent a96c27d commit bdc1161

File tree

3 files changed

+11
-21
lines changed

3 files changed

+11
-21
lines changed

autointent/_wrappers/embedder.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,25 +126,20 @@ def _load_model(self) -> None:
126126
similarity_fn_name=self.config.similarity_fn_name,
127127
trust_remote_code=self.config.trust_remote_code,
128128
)
129+
129130
def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTuningConfig) -> None:
130131
"""Train the embedding model."""
131132
self._load_model()
132133

133-
tr_ds = Dataset.from_dict({
134-
"text": utterances,
135-
"label": labels
136-
})
134+
tr_ds = Dataset.from_dict({"text": utterances, "label": labels})
137135

138-
loss = BatchAllTripletLoss(
139-
model=self.embedding_model,
140-
margin=config.margin
141-
)
136+
loss = BatchAllTripletLoss(model=self.embedding_model, margin=config.margin)
142137
with tempfile.TemporaryDirectory() as tmp_dir:
143138
args = SentenceTransformerTrainingArguments(
144139
save_strategy="no",
145140
output_dir=tmp_dir,
146141
num_train_epochs=config.epoch_num,
147-
per_device_train_batch_size=self.config.batch_size,
142+
per_device_train_batch_size=config.batch_size,
148143
learning_rate=config.learning_rate,
149144
warmup_ratio=config.warmup_ratio,
150145
fp16=config.fp16,

autointent/configs/_transformers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@ class TokenizerConfig(BaseModel):
1414
truncation: bool = True
1515
max_length: PositiveInt | None = Field(None, description="Maximum length of input sequences.")
1616

17+
1718
class EmbedderFineTuningConfig(BaseModel):
1819
epoch_num: int
20+
batch_size: int
1921
margin: float = Field(default=0.5)
2022
learning_rate: float = Field(default=2e-5)
2123
warmup_ratio: float = Field(default=0.1)
2224
fp16: bool = Field(default=True)
2325
bf16: bool = Field(default=False)
2426

27+
2528
class HFModelConfig(BaseModel):
2629
model_config = ConfigDict(extra="forbid")
2730
model_name: str = Field(

tests/embedder/test_fine_tuning.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@ def test_model_updates_after_training(dataset):
99
"""Test that model weights actually change after training"""
1010
data_handler = DataHandler(dataset)
1111

12-
hf_config = HFModelConfig(
13-
model_name="intfloat/multilingual-e5-small",
14-
batch_size=8,
15-
trust_remote_code=True
16-
)
12+
hf_config = HFModelConfig(model_name="intfloat/multilingual-e5-small", batch_size=8, trust_remote_code=True)
1713

1814
embedder_config = EmbedderConfig(
1915
**hf_config.model_dump(),
@@ -22,12 +18,10 @@ def test_model_updates_after_training(dataset):
2218
passage_prompt="Document:",
2319
similarity_fn_name="cosine",
2420
use_cache=False,
25-
freeze=False
21+
freeze=False,
2622
)
2723

28-
train_config = EmbedderFineTuningConfig(
29-
epoch_num = 1
30-
)
24+
train_config = EmbedderFineTuningConfig(epoch_num=1)
3125
embedder = Embedder(embedder_config)
3226
embedder._load_model()
3327

@@ -37,9 +31,7 @@ def test_model_updates_after_training(dataset):
3731
if param.requires_grad
3832
]
3933
embedder.train(
40-
utterances=data_handler.train_utterances(0)[:10],
41-
labels=data_handler.train_labels(0)[:10],
42-
config=train_config
34+
utterances=data_handler.train_utterances(0)[:10], labels=data_handler.train_labels(0)[:10], config=train_config
4335
)
4436

4537
trained_weights = [

0 commit comments

Comments (0)