
Commit 1cc47f6

make sentence-transformers an optional dependency

committed · 1 parent 7c5c65f · commit 1cc47f6

File tree

- Makefile
- pyproject.toml
- src/autointent/_wrappers/embedder/sentence_transformers.py
- src/autointent/_wrappers/ranker.py

4 files changed: +29 −21 lines

Makefile

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ sh = uv run --no-sync --frozen
 .PHONY: install
 install:
 	rm -rf uv.lock
-	uv sync --all-groups --extra catboost --extra peft
+	uv sync --all-groups --extra catboost --extra peft --extra sentence-transformers --extra transformers

 .PHONY: test
 test:

pyproject.toml

Lines changed: 3 additions & 3 deletions
@@ -31,7 +31,7 @@ classifiers=[
 ]
 requires-python = ">=3.10,<3.13"
 dependencies = [
-    "sentence-transformers (>=3,<4)",
+    "torch (>=2.0.0,<3.0.0)",
     "scikit-learn (>=1.5,<2.0)",
     "iterative-stratification (>=0.1.9)",
     "appdirs (>=1.4,<2.0)",
@@ -43,7 +43,6 @@ dependencies = [
     "datasets (>=3.2.0,<4.0.0)",
     "xxhash (>=3.5.0,<4.0.0)",
     "python-dotenv (>=1.0.1,<2.0.0)",
-    "transformers[torch] (>=4.49.0,<5.0.0)",
     "aiometer (>=1.0.0,<2.0.0)",
     "aiofiles (>=24.1.0,<25.0.0)",
     "threadpoolctl (>=3.0.0,<4.0.0)",
@@ -52,7 +51,8 @@ dependencies = [
 [project.optional-dependencies]
 catboost = ["catboost (>=1.2.8,<2.0.0)"]
 peft = ["peft (>= 0.10.0, !=0.15.0, !=0.15.1, <1.0.0)"]
-transformers = ["transformers[torch] (>=4.49.0,<5.0.0)"]
+transformers = ["transformers (>=4.49.0,<5.0.0)"]
+sentence-transformers = ["sentence-transformers (>=3,<4)"]
 dspy = [
     "dspy (>=2.6.5,<3.0.0)",
 ]
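
With this split, a plain install of autointent no longer pulls in sentence-transformers or transformers; only torch stays in the core dependencies, and the heavy libraries move behind the sentence-transformers and transformers extras. A minimal, illustrative check (not part of this commit) that downstream code could use to detect whether the new extra is installed:

    # Illustrative only: probe for the optional dependency before using
    # sentence-transformers-backed features.
    import importlib.util

    if importlib.util.find_spec("sentence_transformers") is None:
        msg = (
            "sentence-transformers is not installed; "
            "install it with: pip install 'autointent[sentence-transformers]'"
        )
        raise ImportError(msg)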

src/autointent/_wrappers/embedder/sentence_transformers.py

Lines changed: 15 additions & 13 deletions
@@ -10,9 +10,6 @@
 import numpy.typing as npt
 import torch
 from datasets import Dataset
-from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
-from sentence_transformers.losses import BatchAllTripletLoss
-from sentence_transformers.training_args import BatchSamplers
 from sklearn.model_selection import train_test_split

 from autointent._hash import Hasher
@@ -25,6 +22,7 @@
 from .utils import get_embeddings_path

 if TYPE_CHECKING:
+    from sentence_transformers import SentenceTransformer
     from transformers import TrainerCallback

 logger = logging.getLogger(__name__)
@@ -51,6 +49,7 @@ class SentenceTransformerEmbeddingBackend(BaseEmbeddingBackend):
     """SentenceTransformer-based embedding backend implementation."""

     supports_training: bool = True
+    _model: "SentenceTransformer | None"

     def __init__(self, config: SentenceTransformerEmbeddingConfig) -> None:
         """Initialize the SentenceTransformer backend.
@@ -59,7 +58,7 @@ def __init__(self, config: SentenceTransformerEmbeddingConfig) -> None:
             config: Configuration for SentenceTransformer embeddings.
         """
         self.config = config
-        self._model: SentenceTransformer | None = None
+        self._model = None
         self._trained: bool = False

     def clear_ram(self) -> None:
@@ -71,10 +70,12 @@ def clear_ram(self) -> None:
         self._model = None
         torch.cuda.empty_cache()

-    def _load_model(self) -> SentenceTransformer:
+    def _load_model(self) -> "SentenceTransformer":
         """Load sentence transformers model to device."""
         if self._model is None:
-            res = SentenceTransformer(
+            # Lazy import sentence-transformers
+            st = require("sentence_transformers", extra="sentence-transformers")
+            res = st.SentenceTransformer(
                 self.config.model_name,
                 device=self.config.device,
                 prompts=self.config.get_prompt_config(),
@@ -231,16 +232,17 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFin

         model = self._load_model()

+        # Lazy import sentence-transformers training components (only needed for fine-tuning)
+        st = require("sentence_transformers", extra="sentence-transformers")
+        transformers = require("transformers", extra="transformers")
+
         x_train, x_val, y_train, y_val = train_test_split(utterances, labels, test_size=config.val_fraction)
         tr_ds = Dataset.from_dict({"text": x_train, "label": y_train})
         val_ds = Dataset.from_dict({"text": x_val, "label": y_val})

-        loss = BatchAllTripletLoss(model=model, margin=config.margin)
+        loss = st.losses.BatchAllTripletLoss(model=model, margin=config.margin)
         with tempfile.TemporaryDirectory() as tmp_dir:
-            # Lazy import transformers (only needed for fine-tuning)
-            transformers = require("transformers", extra="transformers")
-
-            args = SentenceTransformerTrainingArguments(
+            args = st.SentenceTransformerTrainingArguments(
                 save_strategy="epoch",
                 save_total_limit=1,
                 output_dir=tmp_dir,
@@ -251,7 +253,7 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFin
                 warmup_ratio=config.warmup_ratio,
                 fp16=config.fp16,
                 bf16=config.bf16,
-                batch_sampler=BatchSamplers.NO_DUPLICATES,
+                batch_sampler=st.training_args.BatchSamplers.NO_DUPLICATES,
                 metric_for_best_model="eval_loss",
                 load_best_model_at_end=True,
                 eval_strategy="epoch",
@@ -263,7 +265,7 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFin
                     early_stopping_threshold=config.early_stopping_threshold,
                 )
             ]
-            trainer = SentenceTransformerTrainer(
+            trainer = st.SentenceTransformerTrainer(
                 model=model,
                 args=args,
                 train_dataset=tr_ds,
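
Both Python files delegate the lazy import to require from autointent._utils, whose body is not part of this diff. A minimal sketch of what such a helper plausibly does, assuming it returns the imported module and names the missing extra in its error message:

    # Sketch only; the real autointent._utils.require may differ.
    import importlib
    from types import ModuleType


    def require(module_name: str, extra: str) -> ModuleType:
        """Import an optional dependency or fail with a hint at the extra providing it."""
        try:
            return importlib.import_module(module_name)
        except ImportError as e:
            msg = (
                f"'{module_name}' is required for this feature but is not installed. "
                f"Install it with: pip install 'autointent[{extra}]'"
            )
            raise ImportError(msg) from e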

src/autointent/_wrappers/ranker.py

Lines changed: 10 additions & 4 deletions
@@ -10,19 +10,22 @@
 import logging
 from pathlib import Path
 from random import shuffle
-from typing import Any, Literal, TypedDict
+from typing import TYPE_CHECKING, Any, Literal, TypedDict

 import joblib
 import numpy as np
 import numpy.typing as npt
-import sentence_transformers as st
 import torch
 from sklearn.linear_model import LogisticRegressionCV
 from torch import nn

+from autointent._utils import require
 from autointent.configs import CrossEncoderConfig
 from autointent.custom_types import ListOfLabels, RerankedItem

+if TYPE_CHECKING:
+    import sentence_transformers as st
+
 logger = logging.getLogger(__name__)


@@ -95,7 +98,7 @@ class Ranker:
     _metadata_file_name = "metadata.json"
     _classifier_file_name = "classifier.joblib"
     config: CrossEncoderConfig
-    cross_encoder: st.CrossEncoder
+    cross_encoder: "st.CrossEncoder"

     def __init__(
         self,
@@ -110,12 +113,15 @@ def __init__(
             classifier_head: Optional pre-trained classifier head
             output_range: Range of the output probabilities ([0, 1] for sigmoid, [-1, 1] for tanh)
         """
+        # Lazy import sentence-transformers
+        st = require("sentence_transformers", extra="sentence-transformers")
+
         self.config = CrossEncoderConfig.from_search_config(cross_encoder_config)
         self.cross_encoder = st.CrossEncoder(
            self.config.model_name,
            trust_remote_code=self.config.trust_remote_code,
            device=self.config.device,
-           max_length=self.config.tokenizer_config.max_length,  # type: ignore[arg-type]
+           max_length=self.config.tokenizer_config.max_length,
        )
        self._train_head = False
        self._clf = classifier_head
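
The same two-part pattern appears in both Python files: the heavy import moves under if TYPE_CHECKING: (so static type checkers still see it) and the annotations that reference it become strings (so they are never evaluated at runtime). A standalone illustration of the pattern, independent of this codebase; the class name here is hypothetical:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by mypy/pyright, never executed at runtime.
        import sentence_transformers as st


    class Reranker:
        # The quoted annotation is stored as a plain string in __annotations__,
        # so sentence_transformers need not be importable when this class loads.
        cross_encoder: "st.CrossEncoder"

The dropped # type: ignore[arg-type] comment is presumably a side effect of the same change: once st is obtained from require(...) at runtime, the CrossEncoder call is no longer statically typed, so there is nothing left for the ignore to suppress.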
