1- import asyncio
21import os
32from dataclasses import dataclass
43from typing import TYPE_CHECKING
54
65from pydantic import Field , SecretStr
76
87from unstructured_ingest .embed .interfaces import (
9- EMBEDDINGS_KEY ,
108 AsyncBaseEmbeddingEncoder ,
119 BaseEmbeddingEncoder ,
1210 EmbeddingConfig ,
1311)
14- from unstructured_ingest .utils .data_prep import batch_generator
1512from unstructured_ingest .utils .dep_check import requires_dependencies
1613
1714USER_AGENT = "@mixedbread-ai/unstructured"
@@ -85,7 +82,7 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
8582
8683 def get_exemplary_embedding (self ) -> list [float ]:
8784 """Get an exemplary embedding to determine dimensions and unit vector status."""
88- return self ._embed ([ "Q" ])[ 0 ]
85+ return self .embed_query ( query = "Q" )
8986
9087 @requires_dependencies (
9188 ["mixedbread_ai" ],
@@ -100,59 +97,19 @@ def get_request_options(self) -> "RequestOptions":
10097 additional_headers = {"User-Agent" : USER_AGENT },
10198 )
10299
103- def _embed (self , texts : list [str ]) -> list [list [float ]]:
104- """
105- Embed a list of texts using the Mixedbread AI API.
106-
107- Args:
108- texts (list[str]): List of texts to embed.
109-
110- Returns:
111- list[list[float]]: List of embeddings.
112- """
113-
114- responses = []
115- client = self .config .get_client ()
116- for batch in batch_generator (texts , batch_size = self .config .batch_size or len (texts )):
117- response = client .embeddings (
118- model = self .config .embedder_model_name ,
119- normalized = True ,
120- encoding_format = ENCODING_FORMAT ,
121- truncation_strategy = TRUNCATION_STRATEGY ,
122- request_options = self .get_request_options (),
123- input = batch ,
124- )
125- responses .append (response )
126- return [item .embedding for response in responses for item in response .data ]
127-
128- def embed_documents (self , elements : list [dict ]) -> list [dict ]:
129- """
130- Embed a list of document elements.
131-
132- Args:
133- elements (list[Element]): List of document elements.
134-
135- Returns:
136- list[Element]: Elements with embeddings.
137- """
138- elements = elements .copy ()
139- elements_with_text = [e for e in elements if e .get ("text" )]
140- embeddings = self ._embed ([e ["text" ] for e in elements_with_text ])
141- for element , embedding in zip (elements_with_text , embeddings ):
142- element [EMBEDDINGS_KEY ] = embedding
143- return elements
144-
145- def embed_query (self , query : str ) -> list [float ]:
146- """
147- Embed a query string.
148-
149- Args:
150- query (str): Query string to embed.
151-
152- Returns:
153- list[float]: Embedding of the query.
154- """
155- return self ._embed ([query ])[0 ]
100+ def get_client (self ) -> "MixedbreadAI" :
101+ return self .config .get_client ()
102+
103+ def embed_batch (self , client : "MixedbreadAI" , batch : list [str ]) -> list [list [float ]]:
104+ response = client .embeddings (
105+ model = self .config .embedder_model_name ,
106+ normalized = True ,
107+ encoding_format = ENCODING_FORMAT ,
108+ truncation_strategy = TRUNCATION_STRATEGY ,
109+ request_options = self .get_request_options (),
110+ input = batch ,
111+ )
112+ return [datum .embedding for datum in response .data ]
156113
157114
158115@dataclass
@@ -162,8 +119,7 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
162119
163120 async def get_exemplary_embedding (self ) -> list [float ]:
164121 """Get an exemplary embedding to determine dimensions and unit vector status."""
165- embedding = await self ._embed (["Q" ])
166- return embedding [0 ]
122+ return await self .embed_query (query = "Q" )
167123
168124 @requires_dependencies (
169125 ["mixedbread_ai" ],
@@ -178,58 +134,16 @@ def get_request_options(self) -> "RequestOptions":
178134 additional_headers = {"User-Agent" : USER_AGENT },
179135 )
180136
181- async def _embed (self , texts : list [str ]) -> list [list [float ]]:
182- """
183- Embed a list of texts using the Mixedbread AI API.
184-
185- Args:
186- texts (list[str]): List of texts to embed.
187-
188- Returns:
189- list[list[float]]: List of embeddings.
190- """
191- client = self .config .get_async_client ()
192- tasks = []
193- for batch in batch_generator (texts , batch_size = self .config .batch_size or len (texts )):
194- tasks .append (
195- client .embeddings (
196- model = self .config .embedder_model_name ,
197- normalized = True ,
198- encoding_format = ENCODING_FORMAT ,
199- truncation_strategy = TRUNCATION_STRATEGY ,
200- request_options = self .get_request_options (),
201- input = batch ,
202- )
203- )
204- responses = await asyncio .gather (* tasks )
205- return [item .embedding for response in responses for item in response .data ]
206-
207- async def embed_documents (self , elements : list [dict ]) -> list [dict ]:
208- """
209- Embed a list of document elements.
210-
211- Args:
212- elements (list[Element]): List of document elements.
213-
214- Returns:
215- list[Element]: Elements with embeddings.
216- """
217- elements = elements .copy ()
218- elements_with_text = [e for e in elements if e .get ("text" )]
219- embeddings = await self ._embed ([e ["text" ] for e in elements_with_text ])
220- for element , embedding in zip (elements_with_text , embeddings ):
221- element [EMBEDDINGS_KEY ] = embedding
222- return elements
223-
224- async def embed_query (self , query : str ) -> list [float ]:
225- """
226- Embed a query string.
227-
228- Args:
229- query (str): Query string to embed.
230-
231- Returns:
232- list[float]: Embedding of the query.
233- """
234- embedding = await self ._embed ([query ])
235- return embedding [0 ]
137+ def get_client (self ) -> "AsyncMixedbreadAI" :
138+ return self .config .get_async_client ()
139+
140+ async def embed_batch (self , client : "AsyncMixedbreadAI" , batch : list [str ]) -> list [list [float ]]:
141+ response = await client .embeddings (
142+ model = self .config .embedder_model_name ,
143+ normalized = True ,
144+ encoding_format = ENCODING_FORMAT ,
145+ truncation_strategy = TRUNCATION_STRATEGY ,
146+ request_options = self .get_request_options (),
147+ input = batch ,
148+ )
149+ return [datum .embedding for datum in response .data ]
0 commit comments