Merged
31 commits
4caced7
add train
k0lenk4 Jul 3, 2025
c5b5b2c
fixed env
k0lenk4 Jul 3, 2025
e13f171
deleted kwargs and local savings, added config
k0lenk4 Jul 28, 2025
d26eda0
added test for train method
k0lenk4 Jul 28, 2025
ee1b4a1
add EmbedderFineTuningConfig to __init__
k0lenk4 Aug 2, 2025
3052628
correct __init__ in config, remov pytest in test file
k0lenk4 Aug 2, 2025
a96c27d
correct some syntax isues
k0lenk4 Aug 2, 2025
bdc1161
move batch_size to EmbedderFineTuningConfig
k0lenk4 Aug 11, 2025
d71de34
add __init__.py to /test/embedder
k0lenk4 Aug 11, 2025
941c13a
Remove whitespace from blank line
k0lenk4 Aug 11, 2025
3ecdc60
Merge remote-tracking branch 'origin/dev' into feat/train-embeddings
k0lenk4 Aug 11, 2025
0739413
correct errors
k0lenk4 Aug 11, 2025
1e161c6
the number of epochs and train objects have been increased
k0lenk4 Aug 11, 2025
e67f1bc
made lint
k0lenk4 Aug 11, 2025
3c38ec8
add early stopping
k0lenk4 Aug 12, 2025
c743c0b
remake train args
k0lenk4 Aug 12, 2025
71bf957
make a list of callbacks
k0lenk4 Aug 12, 2025
2963a4c
inline type annotation of variable "callback"
k0lenk4 Aug 12, 2025
03e4c59
change save_strategy to "epoch"
k0lenk4 Aug 12, 2025
714f910
default value of fp16 changed to False
k0lenk4 Aug 15, 2025
cb9b2ea
pull dev
voorhs Aug 18, 2025
6aa7abc
integrate embeddings fine-tuning into Embedding modules
voorhs Aug 19, 2025
b88a810
pull dev
voorhs Aug 19, 2025
2970737
Update optimizer_config.schema.json
github-actions[bot] Aug 19, 2025
fcf1f31
clean up `freeze` throughout tests and tutorials
voorhs Aug 19, 2025
0b0c1fa
add comprehensive tests
voorhs Aug 19, 2025
19a74f4
embedder_model -> _model
voorhs Aug 19, 2025
1d73af6
fix early stopping
voorhs Aug 25, 2025
714c8c2
fix tests
voorhs Aug 25, 2025
ebf066b
clear ram bug fix
voorhs Aug 25, 2025
51a9b1a
try to fix windows cleanup issue
voorhs Aug 25, 2025
2 changes: 2 additions & 0 deletions .gitignore
@@ -182,3 +182,5 @@ vector_db*
*.db
*.sqlite
/wandb
model_output/
my.py
2 changes: 1 addition & 1 deletion autointent/_dump_tools/unit_dumpers.py
@@ -12,7 +12,7 @@
from peft import PeftModel
from pydantic import BaseModel
from sklearn.base import BaseEstimator
from transformers import ( # type: ignore[attr-defined]
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
PreTrainedModel,
59 changes: 57 additions & 2 deletions autointent/_wrappers/embedder.py
@@ -7,6 +7,7 @@
import json
import logging
import shutil
import tempfile
from functools import lru_cache
from pathlib import Path
from typing import TypedDict
@@ -16,11 +17,16 @@
import numpy.typing as npt
import torch
from appdirs import user_cache_dir
from sentence_transformers import SentenceTransformer
from datasets import Dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import BatchAllTripletLoss
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.training_args import BatchSamplers
from sklearn.model_selection import train_test_split
from transformers import EarlyStoppingCallback, TrainerCallback

from autointent._hash import Hasher
from autointent.configs import EmbedderConfig, TaskTypeEnum
from autointent.configs import EmbedderConfig, EmbedderFineTuningConfig, TaskTypeEnum

logger = logging.getLogger(__name__)

@@ -123,6 +129,55 @@ def _load_model(self) -> None:
trust_remote_code=self.config.trust_remote_code,
)

def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTuningConfig) -> None:
"""Train the embedding model."""
self._load_model()
if config.early_stopping:
x_train, x_val, y_train, y_val = train_test_split(utterances, labels, test_size=0.1, random_state=42)
tr_ds = Dataset.from_dict({"text": x_train, "label": y_train})
val_ds = Dataset.from_dict({"text": x_val, "label": y_val})
else:
tr_ds = Dataset.from_dict({"text": utterances, "label": labels})
val_ds = None

loss = BatchAllTripletLoss(model=self.embedding_model, margin=config.margin)
with tempfile.TemporaryDirectory() as tmp_dir:
args = SentenceTransformerTrainingArguments(
save_strategy="epoch",
output_dir=tmp_dir,
num_train_epochs=config.epoch_num,
per_device_train_batch_size=config.batch_size,
per_device_eval_batch_size=8,
eval_steps=1,
learning_rate=config.learning_rate,
warmup_ratio=config.warmup_ratio,
fp16=config.fp16,
bf16=config.bf16,
batch_sampler=BatchSamplers.NO_DUPLICATES,
metric_for_best_model="eval_loss",
load_best_model_at_end=True,
eval_strategy="epoch",
greater_is_better=False,
)
callback: list[TrainerCallback] = []
if config.early_stopping:
callback.append(
EarlyStoppingCallback(
early_stopping_patience=config.early_stopping,
early_stopping_threshold=config.early_stopping_threshold,
)
)
trainer = SentenceTransformerTrainer(
model=self.embedding_model,
args=args,
train_dataset=tr_ds,
eval_dataset=val_ds,
loss=loss,
callbacks=callback,
)

trainer.train()

def clear_ram(self) -> None:
"""Move the embedding model to CPU and delete it from memory."""
if hasattr(self, "embedding_model"):
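
A minimal usage sketch of the new fine-tuning entry point (illustrative only: the utterances, labels, and model name are made up, and EmbedderConfig is assumed to accept model_name directly, as HFModelConfig does in the test further down):

from autointent._wrappers.embedder import Embedder
from autointent.configs import EmbedderConfig, EmbedderFineTuningConfig

# Toy intent data (hypothetical); BatchAllTripletLoss needs several examples per label
# so that positive/negative pairs can be mined within each batch.
utterances = [
    "turn on the lights", "switch the lamp on", "make it brighter", "lights please",
    "play some jazz", "put on music", "start my playlist", "play a song",
]
labels = [0, 0, 0, 0, 1, 1, 1, 1]

embedder = Embedder(EmbedderConfig(model_name="sentence-transformers/all-MiniLM-L6-v2"))
config = EmbedderFineTuningConfig(epoch_num=1, batch_size=4, fp16=False)  # fp16=False for CPU-only runs
embedder.train(utterances=utterances, labels=labels, config=config)
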
2 changes: 2 additions & 0 deletions autointent/configs/__init__.py
@@ -7,6 +7,7 @@
CrossEncoderConfig,
EarlyStoppingConfig,
EmbedderConfig,
EmbedderFineTuningConfig,
HFModelConfig,
TaskTypeEnum,
TokenizerConfig,
@@ -17,6 +18,7 @@
"DataConfig",
"EarlyStoppingConfig",
"EmbedderConfig",
"EmbedderFineTuningConfig",
"HFModelConfig",
"HPOConfig",
"InferenceNodeConfig",
12 changes: 12 additions & 0 deletions autointent/configs/_transformers.py
@@ -15,6 +15,18 @@ class TokenizerConfig(BaseModel):
max_length: PositiveInt | None = Field(None, description="Maximum length of input sequences.")


class EmbedderFineTuningConfig(BaseModel):
epoch_num: int
batch_size: int
margin: float = Field(default=0.5)
learning_rate: float = Field(default=2e-5)
warmup_ratio: float = Field(default=0.1)
early_stopping: bool = Field(default=True)
early_stopping_threshold: float = Field(default=0.0)
fp16: bool = Field(default=True)
bf16: bool = Field(default=False)


class HFModelConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
model_name: str = Field(
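
For reference, the new EmbedderFineTuningConfig spelled out with all optional fields (only epoch_num and batch_size are required; the remaining values below simply repeat the declared defaults):

from autointent.configs import EmbedderFineTuningConfig

config = EmbedderFineTuningConfig(
    epoch_num=3,               # required
    batch_size=16,             # required
    margin=0.5,                # margin for BatchAllTripletLoss
    learning_rate=2e-5,
    warmup_ratio=0.1,
    early_stopping=True,       # holds out 10% of the data as an eval split
    early_stopping_threshold=0.0,
    fp16=True,                 # mixed precision; disable on CPU-only machines
    bf16=False,
)
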
2 changes: 1 addition & 1 deletion autointent/context/data_handler/_stratification.py
@@ -13,7 +13,7 @@
from numpy import typing as npt
from sklearn.model_selection import train_test_split
from skmultilearn.model_selection import IterativeStratification
from transformers import set_seed # type: ignore[attr-defined]
from transformers import set_seed

from autointent import Dataset
from autointent.custom_types import LabelType
2 changes: 1 addition & 1 deletion autointent/modules/scoring/_bert.py
@@ -9,7 +9,7 @@
import torch
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
from transformers import ( # type: ignore[attr-defined]
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -188,6 +188,7 @@

llms_txt_exclude = ["autoapi*"]


def setup(app: Sphinx) -> None:
generate_versions_json(repo_root, BASE_URL)
user_guids_dir = app.srcdir / "user_guides"
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -29,7 +29,7 @@ classifiers=[
'Framework :: Sphinx',
'Typing :: Typed',
]
requires-python = ">=3.10,<4.0"
requires-python = ">=3.10,<3.13"
dependencies = [
"sentence-transformers (>=3,<4)",
"scikit-learn (>=1.5,<2.0)",
Empty file added tests/embedder/__init__.py
52 changes: 52 additions & 0 deletions tests/embedder/test_fine_tuning.py
@@ -0,0 +1,52 @@
import numpy as np

from autointent._wrappers.embedder import Embedder
from autointent.configs._transformers import EmbedderConfig, EmbedderFineTuningConfig, HFModelConfig
from autointent.context.data_handler import DataHandler


def test_model_updates_after_training(dataset):
"""Test that model weights actually change after training"""
data_handler = DataHandler(dataset)

hf_config = HFModelConfig(model_name="intfloat/multilingual-e5-small", batch_size=8, trust_remote_code=True)

embedder_config = EmbedderConfig(
**hf_config.model_dump(),
default_prompt="Represent this text for retrieval:",
query_prompt="Search query:",
passage_prompt="Document:",
similarity_fn_name="cosine",
use_cache=False,
freeze=False,
)

train_config = EmbedderFineTuningConfig(epoch_num=3, batch_size=8)
embedder = Embedder(embedder_config)
embedder._load_model()

for param in embedder.embedding_model.parameters():
assert param.requires_grad, "All trainable parameters should have requires_grad=True"

original_weights = [
param.data.detach().cpu().numpy().copy()
for param in embedder.embedding_model.parameters()
if param.requires_grad
]
embedder.train(
utterances=data_handler.train_utterances(0)[:1000],
labels=data_handler.train_labels(0)[:1000],
config=train_config,
)

trained_weights = [
param.data.detach().cpu().numpy().copy()
for param in embedder.embedding_model.parameters()
if param.requires_grad
]

weights_changed = any(
not np.allclose(orig, trained, atol=1e-6)
for orig, trained in zip(original_weights, trained_weights, strict=True)
)
assert weights_changed, "Model weights should change after training"