Skip to content

Commit 4fe3025

Browse files
feat: switch LocalEmbedder to sentence-transformers backend (#508)
1 parent f84e385 commit 4fe3025

File tree

6 files changed

+108
-69
lines changed

6 files changed

+108
-69
lines changed

packages/ragbits-core/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
- Add new fusion strategies for the hybrid vector store: RRF and DBSF (#413)
88
- move sources from ragbits-document-search to ragbits-core (#496)
99
- adding connection check to Azure get_blob_service (#502)
10+
- modify LocalEmbedder to use sentence-transformers instead of torch (#508)
1011

1112
## 0.13.0 (2025-04-02)
1213
- Make the score in VectorStoreResult consistent (always bigger is better)

packages/ragbits-core/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ chroma = [
5151
"chromadb>=0.6.3,<1.0.0",
5252
]
5353
local = [
54+
"sentence-transformers>=4.0.2,<5.0.0",
5455
"torch>=2.2.1,<3.0.0",
5556
"transformers>=4.44.2,<5.0.0",
5657
"numpy>=1.26.0,<2.0.0"
Lines changed: 14 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
from collections.abc import Iterator
1+
from dataclasses import field
2+
from typing import Any
23

34
from ragbits.core.audit import trace
45
from ragbits.core.embeddings import Embedder
56
from ragbits.core.options import Options
67

78
try:
8-
import torch
9-
import torch.nn.functional as F
10-
from transformers import AutoModel, AutoTokenizer
9+
from sentence_transformers import SentenceTransformer
1110

1211
HAS_LOCAL_EMBEDDINGS = True
1312
except ImportError:
@@ -19,30 +18,31 @@ class LocalEmbedderOptions(Options):
1918
Dataclass that represents available call options for the LocalEmbedder client.
2019
"""
2120

22-
batch_size: int = 1
21+
encode_kwargs: dict = field(default_factory=dict)
2322

2423

2524
class LocalEmbedder(Embedder[LocalEmbedderOptions]):
2625
"""
2726
Class for interaction with any encoder available in HuggingFace.
2827
29-
Note: Local implementation is not dedicated for production. Use it only in experiments / evaluation
28+
Note: Local implementation is not dedicated for production. Use it only in experiments / evaluation.
3029
"""
3130

3231
options_cls = LocalEmbedderOptions
3332

3433
def __init__(
3534
self,
3635
model_name: str,
37-
api_key: str | None = None,
3836
default_options: LocalEmbedderOptions | None = None,
37+
**model_kwargs: Any, # noqa: ANN401
3938
) -> None:
40-
"""Constructs a new local LLM instance.
39+
"""
40+
Constructs a new local LLM instance.
4141
4242
Args:
4343
model_name: Name of the model to use.
44-
api_key: The API key for Hugging Face authentication.
4544
default_options: Default options for the embedding model.
45+
model_kwargs: Additional arguments to pass to the SentenceTransformer.
4646
4747
Raises:
4848
ImportError: If the 'local' extra requirements are not installed.
@@ -52,15 +52,12 @@ def __init__(
5252

5353
super().__init__(default_options=default_options)
5454

55-
self.hf_api_key = api_key
5655
self.model_name = model_name
57-
58-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59-
self.model = AutoModel.from_pretrained(self.model_name, token=self.hf_api_key).to(self.device)
60-
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=self.hf_api_key)
56+
self.model = SentenceTransformer(self.model_name, **model_kwargs)
6157

6258
async def embed_text(self, data: list[str], options: LocalEmbedderOptions | None = None) -> list[list[float]]:
63-
"""Calls the appropriate encoder endpoint with the given data and options.
59+
"""
60+
Calls the appropriate encoder endpoint with the given data and options.
6461
6562
Args:
6663
data: List of strings to get embeddings for.
@@ -74,36 +71,7 @@ async def embed_text(self, data: list[str], options: LocalEmbedderOptions | None
7471
data=data,
7572
model_name=self.model_name,
7673
model_obj=repr(self.model),
77-
tokenizer=repr(self.tokenizer),
78-
device=self.device,
7974
options=merged_options.dict(),
8075
) as outputs:
81-
embeddings = []
82-
for batch in self._batch(data, merged_options.batch_size):
83-
batch_dict = self.tokenizer(
84-
batch,
85-
max_length=self.tokenizer.model_max_length,
86-
padding=True,
87-
truncation=True,
88-
return_tensors="pt",
89-
).to(self.device)
90-
with torch.no_grad():
91-
model_outputs = self.model(**batch_dict)
92-
batch_embeddings = self._average_pool(model_outputs.last_hidden_state, batch_dict["attention_mask"])
93-
batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
94-
embeddings.extend(batch_embeddings.to("cpu").tolist())
95-
96-
torch.cuda.empty_cache()
97-
outputs.embeddings = embeddings
98-
return embeddings
99-
100-
@staticmethod
101-
def _batch(data: list[str], batch_size: int) -> Iterator[list[str]]:
102-
length = len(data)
103-
for ndx in range(0, length, batch_size):
104-
yield data[ndx : min(ndx + batch_size, length)]
105-
106-
@staticmethod
107-
def _average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
108-
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
109-
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
76+
outputs.embeddings = self.model.encode(data, **merged_options.encode_kwargs).tolist()
77+
return outputs.embeddings
File renamed without changes.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import pickle
2+
3+
import numpy as np
4+
5+
from ragbits.core.embeddings.local import LocalEmbedder, LocalEmbedderOptions
6+
7+
8+
async def test_local_embedder_embed_text():
9+
embedder = LocalEmbedder("sentence-transformers/all-MiniLM-L6-v2")
10+
11+
result = await embedder.embed_text(["test text"])
12+
13+
# Check that embeddings have the expected shape
14+
assert len(result) == 1
15+
assert len(result[0]) == 384 # This dimension depends on the model
16+
17+
18+
async def test_local_embedder_with_custom_encode_kwargs():
19+
# Test with custom encode parameters
20+
embedder = LocalEmbedder(
21+
"BAAI/bge-small-en-v1.5",
22+
prompts={
23+
"classification": "Classify the following text: ",
24+
"retrieval": "Retrieve semantically similar text: ",
25+
"clustering": "Identify the topic or theme based on the text: ",
26+
},
27+
)
28+
options = LocalEmbedderOptions(encode_kwargs={"prompt_name": "retrieval"})
29+
result = await embedder.embed_text(["test text"], options=options)
30+
31+
assert len(result) == 1
32+
assert len(result[0]) > 0
33+
34+
embedder = LocalEmbedder("BAAI/bge-small-en-v1.5")
35+
result_no_prompt = await embedder.embed_text(["test text"])
36+
37+
# Check that the embeddings with custom prompt are different from the default ones
38+
assert not np.array_equal(result[0], result_no_prompt[0])
39+
40+
41+
def test_local_embedder_pickling():
42+
embedder = LocalEmbedder("sentence-transformers/all-MiniLM-L6-v2")
43+
pickled = pickle.dumps(embedder)
44+
unpickled = pickle.loads(pickled) # noqa: S301
45+
46+
assert isinstance(unpickled, LocalEmbedder)
47+
assert unpickled.model_name == "sentence-transformers/all-MiniLM-L6-v2"
48+
assert unpickled.default_options == embedder.default_options

uv.lock

Lines changed: 44 additions & 23 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)