55import os
66from typing import Any , Dict , List , Optional , Tuple
77
8+ from more_itertools import batched
9+ from openai import APIError
810from openai .lib .azure import AzureOpenAI
911from tqdm import tqdm
1012
11- from haystack import Document , component , default_from_dict , default_to_dict
13+ from haystack import Document , component , default_from_dict , default_to_dict , logging
1214from haystack .utils import Secret , deserialize_secrets_inplace
1315
16+ logger = logging .getLogger (__name__ )
17+
1418
1519@component
1620class AzureOpenAIDocumentEmbedder :
@@ -182,11 +186,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIDocumentEmbedder":
182186 deserialize_secrets_inplace (data ["init_parameters" ], keys = ["api_key" , "azure_ad_token" ])
183187 return default_from_dict (cls , data )
184188
185- def _prepare_texts_to_embed (self , documents : List [Document ]) -> List [ str ]:
189+ def _prepare_texts_to_embed (self , documents : List [Document ]) -> Dict [ str , str ]:
186190 """
187191 Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
188192 """
189- texts_to_embed = []
193+ texts_to_embed = {}
190194 for doc in documents :
191195 meta_values_to_embed = [
192196 str (doc .meta [key ]) for key in self .meta_fields_to_embed if key in doc .meta and doc .meta [key ] is not None
@@ -196,27 +200,35 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
196200 self .prefix + self .embedding_separator .join (meta_values_to_embed + [doc .content or "" ]) + self .suffix
197201 ).replace ("\n " , " " )
198202
199- texts_to_embed . append ( text_to_embed )
203+ texts_to_embed [ doc . id ] = text_to_embed
200204 return texts_to_embed
201205
202- def _embed_batch (self , texts_to_embed : List [ str ], batch_size : int ) -> Tuple [List [List [float ]], Dict [str , Any ]]:
206+ def _embed_batch (self , texts_to_embed : Dict [ str , str ], batch_size : int ) -> Tuple [List [List [float ]], Dict [str , Any ]]:
203207 """
204208 Embed a list of texts in batches.
205209 """
206210
207211 all_embeddings : List [List [float ]] = []
208212 meta : Dict [str , Any ] = {"model" : "" , "usage" : {"prompt_tokens" : 0 , "total_tokens" : 0 }}
209- for i in tqdm (range (0 , len (texts_to_embed ), batch_size ), desc = "Embedding Texts" ):
210- batch = texts_to_embed [i : i + batch_size ]
211- if self .dimensions is not None :
212- response = self ._client .embeddings .create (
213- model = self .azure_deployment , dimensions = self .dimensions , input = batch
214- )
215- else :
216- response = self ._client .embeddings .create (model = self .azure_deployment , input = batch )
217213
218- # Append embeddings to the list
219- all_embeddings .extend (el .embedding for el in response .data )
214+ for batch in tqdm (
215+ batched (texts_to_embed .items (), batch_size ), disable = not self .progress_bar , desc = "Calculating embeddings"
216+ ):
217+ args : Dict [str , Any ] = {"model" : self .azure_deployment , "input" : [b [1 ] for b in batch ]}
218+
219+ if self .dimensions is not None :
220+ args ["dimensions" ] = self .dimensions
221+
222+ try :
223+ response = self ._client .embeddings .create (** args )
224+ except APIError as e :
225+ # Log the error but continue processing
226+ ids = ", " .join (b [0 ] for b in batch )
227+ logger .exception (f"Failed embedding of documents { ids } caused by { e } " )
228+ continue
229+
230+ embeddings = [el .embedding for el in response .data ]
231+ all_embeddings .extend (embeddings )
220232
221233 # Update the meta information only once if it's empty
222234 if not meta ["model" ]:
0 commit comments