@@ -1,13 +1,12 @@
-from collections.abc import Iterator
+from dataclasses import field
+from typing import Any
 
 from ragbits.core.audit import trace
 from ragbits.core.embeddings import Embedder
 from ragbits.core.options import Options
 
 try:
-    import torch
-    import torch.nn.functional as F
-    from transformers import AutoModel, AutoTokenizer
+    from sentence_transformers import SentenceTransformer
 
     HAS_LOCAL_EMBEDDINGS = True
 except ImportError:
@@ -19,30 +18,31 @@ class LocalEmbedderOptions(Options):
     Dataclass that represents available call options for the LocalEmbedder client.
     """
 
-    batch_size: int = 1
+    encode_kwargs: dict = field(default_factory=dict)
 
 
 class LocalEmbedder(Embedder[LocalEmbedderOptions]):
     """
     Class for interaction with any encoder available in HuggingFace.
 
-    Note: Local implementation is not dedicated for production. Use it only in experiments / evaluation
+    Note: The local implementation is not intended for production use. Use it only for experiments / evaluation.
     """
 
     options_cls = LocalEmbedderOptions
 
     def __init__(
         self,
         model_name: str,
-        api_key: str | None = None,
         default_options: LocalEmbedderOptions | None = None,
+        **model_kwargs: Any,  # noqa: ANN401
     ) -> None:
-        """Constructs a new local LLM instance.
+        """
+        Constructs a new local embedder instance.
 
         Args:
             model_name: Name of the model to use.
-            api_key: The API key for Hugging Face authentication.
             default_options: Default options for the embedding model.
+            model_kwargs: Additional arguments to pass to the SentenceTransformer.
 
         Raises:
             ImportError: If the 'local' extra requirements are not installed.
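
To illustrate the two new knobs: model_kwargs are forwarded to the SentenceTransformer constructor, and encode_kwargs are forwarded to its encode() call. A minimal usage sketch follows; the model name and the specific keyword arguments (device, batch_size, normalize_embeddings) are illustrative assumptions, not part of this diff:

# Hypothetical usage sketch; assumes LocalEmbedderOptions accepts keyword construction.
embedder = LocalEmbedder(
    "sentence-transformers/all-MiniLM-L6-v2",  # illustrative model name
    device="cpu",  # goes into **model_kwargs, forwarded to SentenceTransformer(...)
    default_options=LocalEmbedderOptions(
        # forwarded to SentenceTransformer.encode(...)
        encode_kwargs={"batch_size": 64, "normalize_embeddings": True},
    ),
)
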
@@ -52,15 +52,12 @@ def __init__(
 
         super().__init__(default_options=default_options)
 
-        self.hf_api_key = api_key
         self.model_name = model_name
-
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model = AutoModel.from_pretrained(self.model_name, token=self.hf_api_key).to(self.device)
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=self.hf_api_key)
+        self.model = SentenceTransformer(self.model_name, **model_kwargs)
 
     async def embed_text(self, data: list[str], options: LocalEmbedderOptions | None = None) -> list[list[float]]:
-        """Calls the appropriate encoder endpoint with the given data and options.
+        """
+        Calls the appropriate encoder endpoint with the given data and options.
 
         Args:
             data: List of strings to get embeddings for.
@@ -74,36 +71,7 @@ async def embed_text(self, data: list[str], options: LocalEmbedderOptions | None
             data=data,
             model_name=self.model_name,
             model_obj=repr(self.model),
-            tokenizer=repr(self.tokenizer),
-            device=self.device,
             options=merged_options.dict(),
         ) as outputs:
-            embeddings = []
-            for batch in self._batch(data, merged_options.batch_size):
-                batch_dict = self.tokenizer(
-                    batch,
-                    max_length=self.tokenizer.model_max_length,
-                    padding=True,
-                    truncation=True,
-                    return_tensors="pt",
-                ).to(self.device)
-                with torch.no_grad():
-                    model_outputs = self.model(**batch_dict)
-                    batch_embeddings = self._average_pool(model_outputs.last_hidden_state, batch_dict["attention_mask"])
-                batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
-                embeddings.extend(batch_embeddings.to("cpu").tolist())
-
-            torch.cuda.empty_cache()
-            outputs.embeddings = embeddings
-            return embeddings
-
-    @staticmethod
-    def _batch(data: list[str], batch_size: int) -> Iterator[list[str]]:
-        length = len(data)
-        for ndx in range(0, length, batch_size):
-            yield data[ndx : min(ndx + batch_size, length)]
-
-    @staticmethod
-    def _average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
-        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+            outputs.embeddings = self.model.encode(data, **merged_options.encode_kwargs).tolist()
+            return outputs.embeddings
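
Taken together, embed_text now delegates batching and pooling to sentence-transformers instead of hand-rolled torch code. Note that the removed code L2-normalized every embedding, whereas encode() only normalizes when normalize_embeddings=True is passed via encode_kwargs. A short async usage sketch (model name illustrative):

import asyncio

async def main() -> None:
    # Illustrative model name; any SentenceTransformer-compatible HF model works.
    embedder = LocalEmbedder("sentence-transformers/all-MiniLM-L6-v2")
    vectors = await embedder.embed_text(["first document", "second document"])
    print(len(vectors), len(vectors[0]))  # input count, embedding dimension

asyncio.run(main())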