Skip to content

Commit 0d65b4c

Browse files
feat: Enhance error handling in Azure document embedder (#8941)
* feat: Enhance error handling in Azure document embedder
* add release notes
* address review comments
* Update releasenotes/notes/add-azure-embedder-exception-handler-c10ea46fb536de3b.yaml

Co-authored-by: Stefano Fiorucci <[email protected]>

* more alignment with OpenAI impl

---------

Co-authored-by: Stefano Fiorucci <[email protected]>
1 parent 28db039 commit 0d65b4c

File tree

3 files changed

+56
-15
lines changed

3 files changed

+56
-15
lines changed

haystack/components/embedders/azure_document_embedder.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@
55
import os
66
from typing import Any, Dict, List, Optional, Tuple
77

8+
from more_itertools import batched
9+
from openai import APIError
810
from openai.lib.azure import AzureOpenAI
911
from tqdm import tqdm
1012

11-
from haystack import Document, component, default_from_dict, default_to_dict
13+
from haystack import Document, component, default_from_dict, default_to_dict, logging
1214
from haystack.utils import Secret, deserialize_secrets_inplace
1315

16+
logger = logging.getLogger(__name__)
17+
1418

1519
@component
1620
class AzureOpenAIDocumentEmbedder:
@@ -182,11 +186,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIDocumentEmbedder":
182186
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
183187
return default_from_dict(cls, data)
184188

185-
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
189+
def _prepare_texts_to_embed(self, documents: List[Document]) -> Dict[str, str]:
186190
"""
187191
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
188192
"""
189-
texts_to_embed = []
193+
texts_to_embed = {}
190194
for doc in documents:
191195
meta_values_to_embed = [
192196
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
@@ -196,27 +200,35 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
196200
self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
197201
).replace("\n", " ")
198202

199-
texts_to_embed.append(text_to_embed)
203+
texts_to_embed[doc.id] = text_to_embed
200204
return texts_to_embed
201205

202-
def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
206+
def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
203207
"""
204208
Embed a list of texts in batches.
205209
"""
206210

207211
all_embeddings: List[List[float]] = []
208212
meta: Dict[str, Any] = {"model": "", "usage": {"prompt_tokens": 0, "total_tokens": 0}}
209-
for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="Embedding Texts"):
210-
batch = texts_to_embed[i : i + batch_size]
211-
if self.dimensions is not None:
212-
response = self._client.embeddings.create(
213-
model=self.azure_deployment, dimensions=self.dimensions, input=batch
214-
)
215-
else:
216-
response = self._client.embeddings.create(model=self.azure_deployment, input=batch)
217213

218-
# Append embeddings to the list
219-
all_embeddings.extend(el.embedding for el in response.data)
214+
for batch in tqdm(
215+
batched(texts_to_embed.items(), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
216+
):
217+
args: Dict[str, Any] = {"model": self.azure_deployment, "input": [b[1] for b in batch]}
218+
219+
if self.dimensions is not None:
220+
args["dimensions"] = self.dimensions
221+
222+
try:
223+
response = self._client.embeddings.create(**args)
224+
except APIError as e:
225+
# Log the error but continue processing
226+
ids = ", ".join(b[0] for b in batch)
227+
logger.exception(f"Failed embedding of documents {ids} caused by {e}")
228+
continue
229+
230+
embeddings = [el.embedding for el in response.data]
231+
all_embeddings.extend(embeddings)
220232

221233
# Update the meta information only once if it's empty
222234
if not meta["model"]:
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
enhancements:
3+
- |
4+
Improved AzureOpenAIDocumentEmbedder to handle embedding generation failures gracefully.
5+
Errors are logged, and processing continues with the remaining batches.

test/components/embedders/test_azure_document_embedder.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
# SPDX-License-Identifier: Apache-2.0
44
import os
55

6+
from openai import APIError
7+
8+
from haystack.utils.auth import Secret
69
import pytest
710

811
from haystack import Document
912
from haystack.components.embedders import AzureOpenAIDocumentEmbedder
13+
from unittest.mock import Mock, patch
1014

1115

1216
class TestAzureOpenAIDocumentEmbedder:
@@ -83,6 +87,26 @@ def test_from_dict(self, monkeypatch):
8387
assert component.suffix == ""
8488
assert component.default_headers == {}
8589

90+
def test_embed_batch_handles_exceptions_gracefully(self, caplog):
91+
embedder = AzureOpenAIDocumentEmbedder(
92+
azure_endpoint="https://test.openai.azure.com",
93+
api_key=Secret.from_token("fake-api-key"),
94+
azure_deployment="text-embedding-ada-002",
95+
embedding_separator=" | ",
96+
)
97+
98+
fake_texts_to_embed = {"1": "text1", "2": "text2"}
99+
100+
with patch.object(
101+
embedder._client.embeddings,
102+
"create",
103+
side_effect=APIError(message="Mocked error", request=Mock(), body=None),
104+
):
105+
embedder._embed_batch(texts_to_embed=fake_texts_to_embed, batch_size=32)
106+
107+
assert len(caplog.records) == 1
108+
assert "Failed embedding of documents 1, 2 caused by Mocked error" in caplog.text
109+
86110
@pytest.mark.integration
87111
@pytest.mark.skipif(
88112
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),

0 commit comments

Comments
 (0)