1919
2020
2121if TYPE_CHECKING :
22- from mixedbread_ai .client import AsyncMixedbreadAI , MixedbreadAI
23- from mixedbread_ai .core import RequestOptions
22+ from mixedbread import AsyncMixedbread , Mixedbread
2423
2524
2625class MixedbreadAIEmbeddingConfig (EmbeddingConfig ):
@@ -44,31 +43,33 @@ class MixedbreadAIEmbeddingConfig(EmbeddingConfig):
4443 )
4544
4645 @requires_dependencies (
47- ["mixedbread_ai " ],
48- extras = "mixedbreadai" ,
46+ ["mixedbread " ],
47+ extras = "embed- mixedbreadai" ,
4948 )
50- def get_client (self ) -> "MixedbreadAI " :
49+ def get_client (self ) -> "Mixedbread " :
5150 """
5251 Create the Mixedbread AI client.
5352
5453 Returns:
55- MixedbreadAI : Initialized client.
54+ Mixedbread : Initialized client.
5655 """
57- from mixedbread_ai . client import MixedbreadAI
56+ from mixedbread import Mixedbread
5857
59- return MixedbreadAI (
58+ return Mixedbread (
6059 api_key = self .api_key .get_secret_value (),
60+ max_retries = MAX_RETRIES ,
6161 )
6262
6363 @requires_dependencies (
64- ["mixedbread_ai " ],
65- extras = "mixedbreadai" ,
64+ ["mixedbread " ],
65+ extras = "embed- mixedbreadai" ,
6666 )
67- def get_async_client (self ) -> "AsyncMixedbreadAI " :
68- from mixedbread_ai . client import AsyncMixedbreadAI
67+ def get_async_client (self ) -> "AsyncMixedbread " :
68+ from mixedbread import AsyncMixedbread
6969
70- return AsyncMixedbreadAI (
70+ return AsyncMixedbread (
7171 api_key = self .api_key .get_secret_value (),
72+ max_retries = MAX_RETRIES ,
7273 )
7374
7475
@@ -88,29 +89,20 @@ def get_exemplary_embedding(self) -> list[float]:
8889 return self .embed_query (query = "Q" )
8990
9091 @requires_dependencies (
91- ["mixedbread_ai " ],
92+ ["mixedbread " ],
9293 extras = "embed-mixedbreadai" ,
9394 )
94- def get_request_options (self ) -> "RequestOptions" :
95- from mixedbread_ai .core import RequestOptions
96-
97- return RequestOptions (
98- max_retries = MAX_RETRIES ,
99- timeout_in_seconds = TIMEOUT ,
100- additional_headers = {"User-Agent" : USER_AGENT },
101- )
102-
103- def get_client (self ) -> "MixedbreadAI" :
95+ def get_client (self ) -> "Mixedbread" :
10496 return self .config .get_client ()
10597
106- def embed_batch (self , client : "MixedbreadAI " , batch : list [str ]) -> list [list [float ]]:
107- response = client .embeddings (
98+ def embed_batch (self , client : "Mixedbread " , batch : list [str ]) -> list [list [float ]]:
99+ response = client .embed (
108100 model = self .config .embedder_model_name ,
101+ input = batch ,
109102 normalized = True ,
110103 encoding_format = ENCODING_FORMAT ,
111- truncation_strategy = TRUNCATION_STRATEGY ,
112- request_options = self .get_request_options (),
113- input = batch ,
104+ extra_headers = {"User-Agent" : USER_AGENT },
105+ timeout = TIMEOUT ,
114106 )
115107 return [datum .embedding for datum in response .data ]
116108
@@ -124,28 +116,19 @@ async def get_exemplary_embedding(self) -> list[float]:
124116 return await self .embed_query (query = "Q" )
125117
126118 @requires_dependencies (
127- ["mixedbread_ai " ],
119+ ["mixedbread " ],
128120 extras = "embed-mixedbreadai" ,
129121 )
130- def get_request_options (self ) -> "RequestOptions" :
131- from mixedbread_ai .core import RequestOptions
132-
133- return RequestOptions (
134- max_retries = MAX_RETRIES ,
135- timeout_in_seconds = TIMEOUT ,
136- additional_headers = {"User-Agent" : USER_AGENT },
137- )
138-
139- def get_client (self ) -> "AsyncMixedbreadAI" :
122+ def get_client (self ) -> "AsyncMixedbread" :
140123 return self .config .get_async_client ()
141124
142- async def embed_batch (self , client : "AsyncMixedbreadAI " , batch : list [str ]) -> list [list [float ]]:
143- response = await client .embeddings (
125+ async def embed_batch (self , client : "AsyncMixedbread " , batch : list [str ]) -> list [list [float ]]:
126+ response = await client .embed (
144127 model = self .config .embedder_model_name ,
128+ input = batch ,
145129 normalized = True ,
146130 encoding_format = ENCODING_FORMAT ,
147- truncation_strategy = TRUNCATION_STRATEGY ,
148- request_options = self .get_request_options (),
149- input = batch ,
131+ extra_headers = {"User-Agent" : USER_AGENT },
132+ timeout = TIMEOUT ,
150133 )
151134 return [datum .embedding for datum in response .data ]
0 commit comments