Skip to content

Commit 442dbfa

Browse files
authored
feat/refactor embedders (#383)
* refactor embedders * fix mixedbreadai embedder * create default behavior for embed_query * fix default behavior of _embed_query
1 parent 17a5dd3 commit 442dbfa

File tree

13 files changed

+204
-403
lines changed

13 files changed

+204
-403
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
## 0.5.3-dev0
1+
## 0.5.3-dev1
22

33
### Enhancements
44

5+
* **Optimize embedder code** - Move duplicate code to base interface, exit early if no elements have text.
6+
57
### Fixes
68

79
## 0.5.2

unstructured_ingest/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.5.3-dev0" # pragma: no cover
1+
__version__ = "0.5.3-dev1" # pragma: no cover

unstructured_ingest/embed/azure_openai.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,13 @@ def get_async_client(self) -> "AsyncAzureOpenAI":
4444
class AzureOpenAIEmbeddingEncoder(OpenAIEmbeddingEncoder):
4545
config: AzureOpenAIEmbeddingConfig
4646

47+
def get_client(self) -> "AzureOpenAI":
48+
return self.config.get_client()
49+
4750

4851
@dataclass
4952
class AsyncAzureOpenAIEmbeddingEncoder(AsyncOpenAIEmbeddingEncoder):
5053
config: AzureOpenAIEmbeddingConfig
54+
55+
def get_client(self) -> "AsyncAzureOpenAI":
56+
return self.config.get_async_client()

unstructured_ingest/embed/bedrock.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@
1515
)
1616
from unstructured_ingest.logger import logger
1717
from unstructured_ingest.utils.dep_check import requires_dependencies
18-
from unstructured_ingest.v2.errors import ProviderError, RateLimitError, UserAuthError, UserError
18+
from unstructured_ingest.v2.errors import (
19+
ProviderError,
20+
RateLimitError,
21+
UserAuthError,
22+
UserError,
23+
is_internal_error,
24+
)
1925

2026
if TYPE_CHECKING:
2127
from botocore.client import BaseClient
@@ -54,6 +60,8 @@ class BedrockEmbeddingConfig(EmbeddingConfig):
5460
embed_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")
5561

5662
def wrap_error(self, e: Exception) -> Exception:
63+
if is_internal_error(e=e):
64+
return e
5765
from botocore.exceptions import ClientError
5866

5967
if isinstance(e, ClientError):
@@ -148,6 +156,8 @@ def embed_query(self, query: str) -> list[float]:
148156
def embed_documents(self, elements: list[dict]) -> list[dict]:
149157
elements = elements.copy()
150158
elements_with_text = [e for e in elements if e.get("text")]
159+
if not elements_with_text:
160+
return elements
151161
embeddings = [self.embed_query(query=e["text"]) for e in elements_with_text]
152162
for element, embedding in zip(elements_with_text, embeddings):
153163
element[EMBEDDINGS_KEY] = embedding

unstructured_ingest/embed/huggingface.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def get_encoder_kwargs(self) -> dict:
4747
class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
4848
config: HuggingFaceEmbeddingConfig
4949

50-
def embed_query(self, query: str) -> list[float]:
50+
def _embed_query(self, query: str) -> list[float]:
5151
return self._embed_documents(texts=[query])[0]
5252

5353
def _embed_documents(self, texts: list[str]) -> list[list[float]]:
@@ -58,6 +58,8 @@ def _embed_documents(self, texts: list[str]) -> list[list[float]]:
5858
def embed_documents(self, elements: list[dict]) -> list[dict]:
5959
elements = elements.copy()
6060
elements_with_text = [e for e in elements if e.get("text")]
61+
if not elements_with_text:
62+
return elements
6163
embeddings = self._embed_documents([e["text"] for e in elements_with_text])
6264
for element, embedding in zip(elements_with_text, embeddings):
6365
element[EMBEDDINGS_KEY] = embedding

unstructured_ingest/embed/interfaces.py

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
import asyncio
2-
from abc import ABC, abstractmethod
1+
from abc import ABC
32
from dataclasses import dataclass
4-
from typing import Optional
3+
from typing import Any, Optional
54

65
import numpy as np
76
from pydantic import BaseModel, Field
87

8+
from unstructured_ingest.utils.data_prep import batch_generator
9+
910
EMBEDDINGS_KEY = "embeddings"
1011

1112

@@ -50,21 +51,37 @@ def is_unit_vector(self) -> bool:
5051
exemplary_embedding = self.get_exemplary_embedding()
5152
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
5253

53-
@abstractmethod
54-
def embed_documents(self, elements: list[dict]) -> list[dict]:
55-
pass
54+
def get_client(self):
55+
raise NotImplementedError
5656

57-
@abstractmethod
58-
def embed_query(self, query: str) -> list[float]:
59-
pass
57+
def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
58+
raise NotImplementedError
6059

61-
def _embed_documents(self, elements: list[str]) -> list[list[float]]:
62-
results = []
63-
for text in elements:
64-
response = self.embed_query(query=text)
65-
results.append(response)
60+
def embed_documents(self, elements: list[dict]) -> list[dict]:
61+
client = self.get_client()
62+
elements = elements.copy()
63+
elements_with_text = [e for e in elements if e.get("text")]
64+
texts = [e["text"] for e in elements_with_text]
65+
embeddings = []
66+
try:
67+
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
68+
embeddings = self.embed_batch(client=client, batch=batch)
69+
embeddings.extend(embeddings)
70+
except Exception as e:
71+
raise self.wrap_error(e=e)
72+
for element, embedding in zip(elements_with_text, embeddings):
73+
element[EMBEDDINGS_KEY] = embedding
74+
return elements
75+
76+
def _embed_query(self, query: str) -> list[float]:
77+
client = self.get_client()
78+
return self.embed_batch(client=client, batch=[query])[0]
6679

67-
return results
80+
def embed_query(self, query: str) -> list[float]:
81+
try:
82+
return self._embed_query(query=query)
83+
except Exception as e:
84+
raise self.wrap_error(e=e)
6885

6986

7087
@dataclass
@@ -88,14 +105,35 @@ async def is_unit_vector(self) -> bool:
88105
exemplary_embedding = await self.get_exemplary_embedding()
89106
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
90107

91-
@abstractmethod
108+
def get_client(self):
109+
raise NotImplementedError
110+
111+
async def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
112+
raise NotImplementedError
113+
92114
async def embed_documents(self, elements: list[dict]) -> list[dict]:
93-
pass
115+
client = self.get_client()
116+
elements = elements.copy()
117+
elements_with_text = [e for e in elements if e.get("text")]
118+
texts = [e["text"] for e in elements_with_text]
119+
embeddings = []
120+
try:
121+
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
122+
embeddings = await self.embed_batch(client=client, batch=batch)
123+
embeddings.extend(embeddings)
124+
except Exception as e:
125+
raise self.wrap_error(e=e)
126+
for element, embedding in zip(elements_with_text, embeddings):
127+
element[EMBEDDINGS_KEY] = embedding
128+
return elements
129+
130+
async def _embed_query(self, query: str) -> list[float]:
131+
client = self.get_client()
132+
embeddings = await self.embed_batch(client=client, batch=[query])
133+
return embeddings[0]
94134

95-
@abstractmethod
96135
async def embed_query(self, query: str) -> list[float]:
97-
pass
98-
99-
async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
100-
results = await asyncio.gather(*[self.embed_query(query=text) for text in elements])
101-
return results
136+
try:
137+
return await self._embed_query(query=query)
138+
except Exception as e:
139+
raise self.wrap_error(e=e)
Lines changed: 28 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
1-
import asyncio
21
import os
32
from dataclasses import dataclass
43
from typing import TYPE_CHECKING
54

65
from pydantic import Field, SecretStr
76

87
from unstructured_ingest.embed.interfaces import (
9-
EMBEDDINGS_KEY,
108
AsyncBaseEmbeddingEncoder,
119
BaseEmbeddingEncoder,
1210
EmbeddingConfig,
1311
)
14-
from unstructured_ingest.utils.data_prep import batch_generator
1512
from unstructured_ingest.utils.dep_check import requires_dependencies
1613

1714
USER_AGENT = "@mixedbread-ai/unstructured"
@@ -85,7 +82,7 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
8582

8683
def get_exemplary_embedding(self) -> list[float]:
8784
"""Get an exemplary embedding to determine dimensions and unit vector status."""
88-
return self._embed(["Q"])[0]
85+
return self.embed_query(query="Q")
8986

9087
@requires_dependencies(
9188
["mixedbread_ai"],
@@ -100,59 +97,19 @@ def get_request_options(self) -> "RequestOptions":
10097
additional_headers={"User-Agent": USER_AGENT},
10198
)
10299

103-
def _embed(self, texts: list[str]) -> list[list[float]]:
104-
"""
105-
Embed a list of texts using the Mixedbread AI API.
106-
107-
Args:
108-
texts (list[str]): List of texts to embed.
109-
110-
Returns:
111-
list[list[float]]: List of embeddings.
112-
"""
113-
114-
responses = []
115-
client = self.config.get_client()
116-
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
117-
response = client.embeddings(
118-
model=self.config.embedder_model_name,
119-
normalized=True,
120-
encoding_format=ENCODING_FORMAT,
121-
truncation_strategy=TRUNCATION_STRATEGY,
122-
request_options=self.get_request_options(),
123-
input=batch,
124-
)
125-
responses.append(response)
126-
return [item.embedding for response in responses for item in response.data]
127-
128-
def embed_documents(self, elements: list[dict]) -> list[dict]:
129-
"""
130-
Embed a list of document elements.
131-
132-
Args:
133-
elements (list[Element]): List of document elements.
134-
135-
Returns:
136-
list[Element]: Elements with embeddings.
137-
"""
138-
elements = elements.copy()
139-
elements_with_text = [e for e in elements if e.get("text")]
140-
embeddings = self._embed([e["text"] for e in elements_with_text])
141-
for element, embedding in zip(elements_with_text, embeddings):
142-
element[EMBEDDINGS_KEY] = embedding
143-
return elements
144-
145-
def embed_query(self, query: str) -> list[float]:
146-
"""
147-
Embed a query string.
148-
149-
Args:
150-
query (str): Query string to embed.
151-
152-
Returns:
153-
list[float]: Embedding of the query.
154-
"""
155-
return self._embed([query])[0]
100+
def get_client(self) -> "MixedbreadAI":
101+
return self.config.get_client()
102+
103+
def embed_batch(self, client: "MixedbreadAI", batch: list[str]) -> list[list[float]]:
104+
response = client.embeddings(
105+
model=self.config.embedder_model_name,
106+
normalized=True,
107+
encoding_format=ENCODING_FORMAT,
108+
truncation_strategy=TRUNCATION_STRATEGY,
109+
request_options=self.get_request_options(),
110+
input=batch,
111+
)
112+
return [datum.embedding for datum in response.data]
156113

157114

158115
@dataclass
@@ -162,8 +119,7 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
162119

163120
async def get_exemplary_embedding(self) -> list[float]:
164121
"""Get an exemplary embedding to determine dimensions and unit vector status."""
165-
embedding = await self._embed(["Q"])
166-
return embedding[0]
122+
return await self.embed_query(query="Q")
167123

168124
@requires_dependencies(
169125
["mixedbread_ai"],
@@ -178,58 +134,16 @@ def get_request_options(self) -> "RequestOptions":
178134
additional_headers={"User-Agent": USER_AGENT},
179135
)
180136

181-
async def _embed(self, texts: list[str]) -> list[list[float]]:
182-
"""
183-
Embed a list of texts using the Mixedbread AI API.
184-
185-
Args:
186-
texts (list[str]): List of texts to embed.
187-
188-
Returns:
189-
list[list[float]]: List of embeddings.
190-
"""
191-
client = self.config.get_async_client()
192-
tasks = []
193-
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
194-
tasks.append(
195-
client.embeddings(
196-
model=self.config.embedder_model_name,
197-
normalized=True,
198-
encoding_format=ENCODING_FORMAT,
199-
truncation_strategy=TRUNCATION_STRATEGY,
200-
request_options=self.get_request_options(),
201-
input=batch,
202-
)
203-
)
204-
responses = await asyncio.gather(*tasks)
205-
return [item.embedding for response in responses for item in response.data]
206-
207-
async def embed_documents(self, elements: list[dict]) -> list[dict]:
208-
"""
209-
Embed a list of document elements.
210-
211-
Args:
212-
elements (list[Element]): List of document elements.
213-
214-
Returns:
215-
list[Element]: Elements with embeddings.
216-
"""
217-
elements = elements.copy()
218-
elements_with_text = [e for e in elements if e.get("text")]
219-
embeddings = await self._embed([e["text"] for e in elements_with_text])
220-
for element, embedding in zip(elements_with_text, embeddings):
221-
element[EMBEDDINGS_KEY] = embedding
222-
return elements
223-
224-
async def embed_query(self, query: str) -> list[float]:
225-
"""
226-
Embed a query string.
227-
228-
Args:
229-
query (str): Query string to embed.
230-
231-
Returns:
232-
list[float]: Embedding of the query.
233-
"""
234-
embedding = await self._embed([query])
235-
return embedding[0]
137+
def get_client(self) -> "AsyncMixedbreadAI":
138+
return self.config.get_async_client()
139+
140+
async def embed_batch(self, client: "AsyncMixedbreadAI", batch: list[str]) -> list[list[float]]:
141+
response = await client.embeddings(
142+
model=self.config.embedder_model_name,
143+
normalized=True,
144+
encoding_format=ENCODING_FORMAT,
145+
truncation_strategy=TRUNCATION_STRATEGY,
146+
request_options=self.get_request_options(),
147+
input=batch,
148+
)
149+
return [datum.embedding for datum in response.data]

0 commit comments

Comments
 (0)