11import logging
22from abc import ABC
33from collections .abc import Awaitable , Callable
4- from typing import Optional
54from urllib .parse import urljoin
65
76import aiohttp
87import tiktoken
9- from azure .core .credentials import AzureKeyCredential
10- from azure .core .credentials_async import AsyncTokenCredential
11- from azure .identity .aio import get_bearer_token_provider
128from openai import AsyncOpenAI , RateLimitError
139from tenacity import (
1410 AsyncRetrying ,
2218
2319
2420class EmbeddingBatch :
25- """
26- Represents a batch of text that is going to be embedded
27- """
21+ """Represents a batch of text that is going to be embedded."""
2822
2923 def __init__ (self , texts : list [str ], token_length : int ):
3024 self .texts = texts
@@ -36,12 +30,9 @@ class ExtraArgs(TypedDict, total=False):
3630
3731
3832class OpenAIEmbeddings (ABC ):
39- """
40- Contains common logic across both OpenAI and Azure OpenAI embedding services
41- Can split source text into batches for more efficient embedding calls
42- """
33+ """Client wrapper that handles batching, retries, and token accounting."""
4334
44- SUPPORTED_BATCH_AOAI_MODEL = {
35+ SUPPORTED_BATCH_MODEL = {
4536 "text-embedding-ada-002" : {"token_limit" : 8100 , "max_batch_size" : 16 },
4637 "text-embedding-3-small" : {"token_limit" : 8100 , "max_batch_size" : 16 },
4738 "text-embedding-3-large" : {"token_limit" : 8100 , "max_batch_size" : 16 },
@@ -52,13 +43,26 @@ class OpenAIEmbeddings(ABC):
5243 "text-embedding-3-large" : True ,
5344 }
5445
55- def __init__ (self , open_ai_model_name : str , open_ai_dimensions : int , disable_batch : bool = False ):
46+ def __init__ (
47+ self ,
48+ open_ai_client : AsyncOpenAI ,
49+ open_ai_model_name : str ,
50+ open_ai_dimensions : int ,
51+ * ,
52+ disable_batch : bool = False ,
53+ azure_deployment_name : str | None = None ,
54+ azure_endpoint : str | None = None ,
55+ ):
56+ self .open_ai_client = open_ai_client
5657 self .open_ai_model_name = open_ai_model_name
5758 self .open_ai_dimensions = open_ai_dimensions
5859 self .disable_batch = disable_batch
60+ self .azure_deployment_name = azure_deployment_name
61+ self .azure_endpoint = azure_endpoint .rstrip ("/" ) if azure_endpoint else None
5962
60- async def create_client (self ) -> AsyncOpenAI :
61- raise NotImplementedError
63+ @property
64+ def _api_model (self ) -> str :
65+ return self .azure_deployment_name or self .open_ai_model_name
6266
6367 def before_retry_sleep (self , retry_state ):
6468 logger .info ("Rate limited on the OpenAI embeddings API, sleeping before retrying..." )
@@ -68,7 +72,7 @@ def calculate_token_length(self, text: str):
6872 return len (encoding .encode (text ))
6973
7074 def split_text_into_batches (self , texts : list [str ]) -> list [EmbeddingBatch ]:
71- batch_info = OpenAIEmbeddings .SUPPORTED_BATCH_AOAI_MODEL .get (self .open_ai_model_name )
75+ batch_info = OpenAIEmbeddings .SUPPORTED_BATCH_MODEL .get (self .open_ai_model_name )
7276 if not batch_info :
7377 raise NotImplementedError (
7478 f"Model { self .open_ai_model_name } is not supported with batch embedding operations"
@@ -101,7 +105,6 @@ def split_text_into_batches(self, texts: list[str]) -> list[EmbeddingBatch]:
101105 async def create_embedding_batch (self , texts : list [str ], dimensions_args : ExtraArgs ) -> list [list [float ]]:
102106 batches = self .split_text_into_batches (texts )
103107 embeddings = []
104- client = await self .create_client ()
105108 for batch in batches :
106109 async for attempt in AsyncRetrying (
107110 retry = retry_if_exception_type (RateLimitError ),
@@ -110,8 +113,8 @@ async def create_embedding_batch(self, texts: list[str], dimensions_args: ExtraA
110113 before_sleep = self .before_retry_sleep ,
111114 ):
112115 with attempt :
113- emb_response = await client .embeddings .create (
114- model = self .open_ai_model_name , input = batch .texts , ** dimensions_args
116+ emb_response = await self . open_ai_client .embeddings .create (
117+ model = self ._api_model , input = batch .texts , ** dimensions_args
115118 )
116119 embeddings .extend ([data .embedding for data in emb_response .data ])
117120 logger .info (
@@ -123,16 +126,15 @@ async def create_embedding_batch(self, texts: list[str], dimensions_args: ExtraA
123126 return embeddings
124127
125128 async def create_embedding_single (self , text : str , dimensions_args : ExtraArgs ) -> list [float ]:
126- client = await self .create_client ()
127129 async for attempt in AsyncRetrying (
128130 retry = retry_if_exception_type (RateLimitError ),
129131 wait = wait_random_exponential (min = 15 , max = 60 ),
130132 stop = stop_after_attempt (15 ),
131133 before_sleep = self .before_retry_sleep ,
132134 ):
133135 with attempt :
134- emb_response = await client .embeddings .create (
135- model = self .open_ai_model_name , input = text , ** dimensions_args
136+ emb_response = await self . open_ai_client .embeddings .create (
137+ model = self ._api_model , input = text , ** dimensions_args
136138 )
137139 logger .info ("Computed embedding for text section. Character count: %d" , len (text ))
138140
@@ -146,85 +148,12 @@ async def create_embeddings(self, texts: list[str]) -> list[list[float]]:
146148 else {}
147149 )
148150
149- if not self .disable_batch and self .open_ai_model_name in OpenAIEmbeddings .SUPPORTED_BATCH_AOAI_MODEL :
151+ if not self .disable_batch and self .open_ai_model_name in OpenAIEmbeddings .SUPPORTED_BATCH_MODEL :
150152 return await self .create_embedding_batch (texts , dimensions_args )
151153
152154 return [await self .create_embedding_single (text , dimensions_args ) for text in texts ]
153155
154156
155- class AzureOpenAIEmbeddingService (OpenAIEmbeddings ):
156- """
157- Class for using Azure OpenAI embeddings
158- To learn more please visit https://learn.microsoft.com/azure/ai-services/openai/concepts/understand-embeddings
159- """
160-
161- def __init__ (
162- self ,
163- open_ai_service : Optional [str ],
164- open_ai_deployment : Optional [str ],
165- open_ai_model_name : str ,
166- open_ai_dimensions : int ,
167- credential : AsyncTokenCredential | AzureKeyCredential ,
168- open_ai_custom_url : Optional [str ] = None ,
169- disable_batch : bool = False ,
170- ):
171- super ().__init__ (open_ai_deployment or open_ai_model_name , open_ai_dimensions , disable_batch )
172- self .open_ai_service = open_ai_service
173- if open_ai_service :
174- self .open_ai_endpoint = f"https://{ open_ai_service } .openai.azure.com"
175- elif open_ai_custom_url :
176- self .open_ai_endpoint = open_ai_custom_url
177- else :
178- raise ValueError ("Either open_ai_service or open_ai_custom_url must be provided" )
179- self .open_ai_deployment = open_ai_deployment
180- self .credential = credential
181-
182- async def create_client (self ) -> AsyncOpenAI :
183- class AuthArgs (TypedDict , total = False ):
184- api_key : str
185-
186- auth_args = AuthArgs ()
187- if isinstance (self .credential , AzureKeyCredential ):
188- auth_args ["api_key" ] = self .credential .key
189- elif isinstance (self .credential , AsyncTokenCredential ):
190- token_provider = get_bearer_token_provider (
191- self .credential , "https://cognitiveservices.azure.com/.default"
192- )
193- auth_args ["api_key" ] = token_provider
194- else :
195- raise TypeError ("Invalid credential type" )
196-
197- # For Azure OpenAI, we need to use the v1 endpoint
198- base_url = f"{ self .open_ai_endpoint } /openai/v1"
199-
200- return AsyncOpenAI (
201- base_url = base_url ,
202- ** auth_args ,
203- )
204-
205-
206- class OpenAIEmbeddingService (OpenAIEmbeddings ):
207- """
208- Class for using OpenAI embeddings
209- To learn more please visit https://platform.openai.com/docs/guides/embeddings
210- """
211-
212- def __init__ (
213- self ,
214- open_ai_model_name : str ,
215- open_ai_dimensions : int ,
216- credential : str ,
217- organization : Optional [str ] = None ,
218- disable_batch : bool = False ,
219- ):
220- super ().__init__ (open_ai_model_name , open_ai_dimensions , disable_batch )
221- self .credential = credential
222- self .organization = organization
223-
224- async def create_client (self ) -> AsyncOpenAI :
225- return AsyncOpenAI (api_key = self .credential , organization = self .organization )
226-
227-
228157class ImageEmbeddings :
229158 """
230159 Class for using image embeddings from Azure AI Vision
0 commit comments