
Commit 5d428c8

Refactoring of embeddings to use the same client
1 parent 4bc8d18 commit 5d428c8

File tree

9 files changed: +252 -378 lines

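The commit message refers to reusing one client: the embeddings helpers no longer build their own AsyncOpenAI client per call but receive the same client that is already configured for chat completions against the Azure OpenAI v1 endpoint. A minimal sketch of that shared-client pattern, assuming passwordless auth and a hypothetical service name:

from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
from openai import AsyncOpenAI

credential = DefaultAzureCredential()
token_provider = get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default")

# One client against the /openai/v1 endpoint, shared by chat completions and embeddings.
openai_client = AsyncOpenAI(
    base_url="https://my-openai-service.openai.azure.com/openai/v1",  # hypothetical service name
    api_key=token_provider,  # the diff below passes a bearer token provider as api_key
)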

app/backend/app.py

Lines changed: 8 additions & 8 deletions
@@ -90,7 +90,6 @@
 from error import error_dict, error_response
 from prepdocs import (
     OpenAIHost,
-    clean_key_if_exists,
     setup_embeddings_service,
     setup_file_processors,
     setup_image_embeddings_service,
@@ -426,6 +425,11 @@ async def setup_clients():
         os.getenv("AZURE_OPENAI_EMB_DEPLOYMENT") if OPENAI_HOST in [OpenAIHost.AZURE, OpenAIHost.AZURE_CUSTOM] else None
     )
     AZURE_OPENAI_CUSTOM_URL = os.getenv("AZURE_OPENAI_CUSTOM_URL")
+    AZURE_OPENAI_ENDPOINT = (
+        os.getenv("AZURE_OPENAI_ENDPOINT")
+        or (AZURE_OPENAI_CUSTOM_URL if OPENAI_HOST == OpenAIHost.AZURE_CUSTOM else None)
+        or (f"https://{AZURE_OPENAI_SERVICE}.openai.azure.com" if AZURE_OPENAI_SERVICE else None)
+    )
     AZURE_VISION_ENDPOINT = os.getenv("AZURE_VISION_ENDPOINT", "")
     AZURE_OPENAI_API_KEY_OVERRIDE = os.getenv("AZURE_OPENAI_API_KEY_OVERRIDE")
     # Used only with non-Azure OpenAI deployments
@@ -599,16 +603,12 @@ async def setup_clients():
         search_service=AZURE_SEARCH_SERVICE, index_name=AZURE_SEARCH_INDEX, azure_credential=azure_credential
     )
     text_embeddings_service = setup_embeddings_service(
-        azure_credential=azure_credential,
-        openai_host=OpenAIHost(OPENAI_HOST),
+        open_ai_client=openai_client,
+        openai_host=OPENAI_HOST,
         emb_model_name=OPENAI_EMB_MODEL,
         emb_model_dimensions=OPENAI_EMB_DIMENSIONS,
-        azure_openai_service=AZURE_OPENAI_SERVICE,
-        azure_openai_custom_url=AZURE_OPENAI_CUSTOM_URL,
         azure_openai_deployment=AZURE_OPENAI_EMB_DEPLOYMENT,
-        azure_openai_key=clean_key_if_exists(AZURE_OPENAI_API_KEY_OVERRIDE),
-        openai_key=clean_key_if_exists(OPENAI_API_KEY),
-        openai_org=OPENAI_ORGANIZATION,
+        azure_openai_endpoint=AZURE_OPENAI_ENDPOINT,
         disable_vectors=os.getenv("USE_VECTORS", "").lower() == "false",
     )
     image_embeddings_service = setup_image_embeddings_service(
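The new AZURE_OPENAI_ENDPOINT block resolves the embeddings endpoint with a fixed precedence: an explicit AZURE_OPENAI_ENDPOINT variable wins, then the custom URL (only for the azure_custom host), then an endpoint derived from the service name. A standalone sketch of that precedence; the helper name is hypothetical and a plain string stands in for the OpenAIHost enum used in the real code:

import os

def resolve_azure_openai_endpoint(openai_host: str) -> str | None:
    # Mirrors the fallback chain added to setup_clients above (illustrative helper).
    custom_url = os.getenv("AZURE_OPENAI_CUSTOM_URL")
    service = os.getenv("AZURE_OPENAI_SERVICE")
    return (
        os.getenv("AZURE_OPENAI_ENDPOINT")
        or (custom_url if openai_host == "azure_custom" else None)  # real code compares against OpenAIHost.AZURE_CUSTOM
        or (f"https://{service}.openai.azure.com" if service else None)
    )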

app/backend/prepdocs.py

Lines changed: 39 additions & 54 deletions
@@ -15,11 +15,7 @@
 from load_azd_env import load_azd_env
 from prepdocslib.blobmanager import BlobManager
 from prepdocslib.csvparser import CsvParser
-from prepdocslib.embeddings import (
-    AzureOpenAIEmbeddingService,
-    ImageEmbeddings,
-    OpenAIEmbeddingService,
-)
+from prepdocslib.embeddings import ImageEmbeddings, OpenAIEmbeddings
 from prepdocslib.fileprocessor import FileProcessor
 from prepdocslib.filestrategy import FileStrategy
 from prepdocslib.htmlparser import LocalHTMLParser
@@ -160,46 +156,37 @@ class OpenAIHost(str, Enum):


 def setup_embeddings_service(
-    azure_credential: AsyncTokenCredential,
+    open_ai_client: AsyncOpenAI,
     openai_host: OpenAIHost,
     emb_model_name: str,
     emb_model_dimensions: int,
-    azure_openai_service: Optional[str],
-    azure_openai_custom_url: Optional[str],
-    azure_openai_deployment: Optional[str],
-    azure_openai_key: Optional[str],
-    openai_key: Optional[str],
-    openai_org: Optional[str],
+    azure_openai_deployment: str | None,
+    azure_openai_endpoint: str | None,
     disable_vectors: bool = False,
     disable_batch_vectors: bool = False,
 ):
     if disable_vectors:
         logger.info("Not setting up embeddings service")
         return None

+    azure_endpoint = None
+    azure_deployment = None
     if openai_host in [OpenAIHost.AZURE, OpenAIHost.AZURE_CUSTOM]:
-        azure_open_ai_credential: AsyncTokenCredential | AzureKeyCredential = (
-            azure_credential if azure_openai_key is None else AzureKeyCredential(azure_openai_key)
-        )
-        return AzureOpenAIEmbeddingService(
-            open_ai_service=azure_openai_service,
-            open_ai_custom_url=azure_openai_custom_url,
-            open_ai_deployment=azure_openai_deployment,
-            open_ai_model_name=emb_model_name,
-            open_ai_dimensions=emb_model_dimensions,
-            credential=azure_open_ai_credential,
-            disable_batch=disable_batch_vectors,
-        )
-    else:
-        if openai_key is None:
-            raise ValueError("OpenAI key is required when using the non-Azure OpenAI API")
-        return OpenAIEmbeddingService(
-            open_ai_model_name=emb_model_name,
-            open_ai_dimensions=emb_model_dimensions,
-            credential=openai_key,
-            organization=openai_org,
-            disable_batch=disable_batch_vectors,
-        )
+        if azure_openai_endpoint is None:
+            raise ValueError("Azure OpenAI endpoint must be provided when using Azure OpenAI embeddings")
+        if azure_openai_deployment is None:
+            raise ValueError("Azure OpenAI deployment must be provided when using Azure OpenAI embeddings")
+        azure_endpoint = azure_openai_endpoint
+        azure_deployment = azure_openai_deployment
+
+    return OpenAIEmbeddings(
+        open_ai_client=open_ai_client,
+        open_ai_model_name=emb_model_name,
+        open_ai_dimensions=emb_model_dimensions,
+        disable_batch=disable_batch_vectors,
+        azure_deployment_name=azure_deployment,
+        azure_endpoint=azure_endpoint,
+    )


 def setup_openai_client(
@@ -226,17 +213,15 @@ def setup_openai_client(
         logger.info("OPENAI_HOST is azure, setting up Azure OpenAI client")
         if not azure_openai_service:
             raise ValueError("AZURE_OPENAI_SERVICE must be set when OPENAI_HOST is azure")
-        endpoint = f"https://{azure_openai_service}.openai.azure.com"
+        endpoint = f"https://{azure_openai_service}.openai.azure.com/openai/v1"
         if azure_openai_api_key:
             logger.info("AZURE_OPENAI_API_KEY_OVERRIDE found, using as api_key for Azure OpenAI client")
-            openai_client = AsyncOpenAI(
-                base_url=f"{endpoint}/openai/v1", api_key=azure_openai_api_key
-            )
+            openai_client = AsyncOpenAI(base_url=endpoint, api_key=azure_openai_api_key)
         else:
             logger.info("Using Azure credential (passwordless authentication) for Azure OpenAI client")
             token_provider = get_bearer_token_provider(azure_credential, "https://cognitiveservices.azure.com/.default")
             openai_client = AsyncOpenAI(
-                base_url=f"{endpoint}/openai/v1",
+                base_url=endpoint,
                 api_key=token_provider,
             )
     elif openai_host == OpenAIHost.LOCAL:
@@ -515,20 +500,6 @@ async def main(strategy: Strategy, setup_index: bool = True):
     emb_model_dimensions = 1536
     if os.getenv("AZURE_OPENAI_EMB_DIMENSIONS"):
         emb_model_dimensions = int(os.environ["AZURE_OPENAI_EMB_DIMENSIONS"])
-    openai_embeddings_service = setup_embeddings_service(
-        azure_credential=azd_credential,
-        openai_host=OPENAI_HOST,
-        emb_model_name=os.environ["AZURE_OPENAI_EMB_MODEL_NAME"],
-        emb_model_dimensions=emb_model_dimensions,
-        azure_openai_service=os.getenv("AZURE_OPENAI_SERVICE"),
-        azure_openai_custom_url=os.getenv("AZURE_OPENAI_CUSTOM_URL"),
-        azure_openai_deployment=os.getenv("AZURE_OPENAI_EMB_DEPLOYMENT"),
-        azure_openai_key=os.getenv("AZURE_OPENAI_API_KEY_OVERRIDE"),
-        openai_key=clean_key_if_exists(os.getenv("OPENAI_API_KEY")),
-        openai_org=os.getenv("OPENAI_ORGANIZATION"),
-        disable_vectors=dont_use_vectors,
-        disable_batch_vectors=args.disablebatchvectors,
-    )
     openai_client = setup_openai_client(
         openai_host=OPENAI_HOST,
         azure_credential=azd_credential,
@@ -538,11 +509,25 @@ async def main(strategy: Strategy, setup_index: bool = True):
         openai_api_key=clean_key_if_exists(os.getenv("OPENAI_API_KEY")),
         openai_organization=os.getenv("OPENAI_ORGANIZATION"),
     )
+    azure_embedding_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("AZURE_OPENAI_CUSTOM_URL")
+    if not azure_embedding_endpoint and OPENAI_HOST == OpenAIHost.AZURE:
+        if service := os.getenv("AZURE_OPENAI_SERVICE"):
+            azure_embedding_endpoint = f"https://{service}.openai.azure.com"
+    openai_embeddings_service = setup_embeddings_service(
+        open_ai_client=openai_client,
+        openai_host=OPENAI_HOST,
+        emb_model_name=os.environ["AZURE_OPENAI_EMB_MODEL_NAME"],
+        emb_model_dimensions=emb_model_dimensions,
+        azure_openai_deployment=os.getenv("AZURE_OPENAI_EMB_DEPLOYMENT"),
+        azure_openai_endpoint=azure_embedding_endpoint,
+        disable_vectors=dont_use_vectors,
+        disable_batch_vectors=args.disablebatchvectors,
+    )

     ingestion_strategy: Strategy
     if use_int_vectorization:

-        if not openai_embeddings_service or not isinstance(openai_embeddings_service, AzureOpenAIEmbeddingService):
+        if not openai_embeddings_service or OPENAI_HOST not in [OpenAIHost.AZURE, OpenAIHost.AZURE_CUSTOM]:
             raise Exception("Integrated vectorization strategy requires an Azure OpenAI embeddings service")

         ingestion_strategy = IntegratedVectorizerStrategy(
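With this refactor the non-Azure path no longer needs a separate OpenAIEmbeddingService: the same setup_embeddings_service call simply receives a client built for openai.com and leaves the Azure-only arguments as None. A hedged usage sketch; the key, model name, and the OpenAIHost.OPENAI member are assumptions for illustration:

from openai import AsyncOpenAI
from prepdocs import OpenAIHost, setup_embeddings_service

client = AsyncOpenAI(api_key="sk-...")  # placeholder key; setup_openai_client builds this client in the real flow
embeddings_service = setup_embeddings_service(
    open_ai_client=client,
    openai_host=OpenAIHost.OPENAI,  # assumed non-Azure member of the enum
    emb_model_name="text-embedding-3-small",
    emb_model_dimensions=1536,
    azure_openai_deployment=None,
    azure_openai_endpoint=None,
)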

app/backend/prepdocslib/embeddings.py

Lines changed: 25 additions & 96 deletions
@@ -1,14 +1,10 @@
 import logging
 from abc import ABC
 from collections.abc import Awaitable, Callable
-from typing import Optional
 from urllib.parse import urljoin

 import aiohttp
 import tiktoken
-from azure.core.credentials import AzureKeyCredential
-from azure.core.credentials_async import AsyncTokenCredential
-from azure.identity.aio import get_bearer_token_provider
 from openai import AsyncOpenAI, RateLimitError
 from tenacity import (
     AsyncRetrying,
@@ -22,9 +18,7 @@


 class EmbeddingBatch:
-    """
-    Represents a batch of text that is going to be embedded
-    """
+    """Represents a batch of text that is going to be embedded."""

     def __init__(self, texts: list[str], token_length: int):
         self.texts = texts
@@ -36,12 +30,9 @@ class ExtraArgs(TypedDict, total=False):


 class OpenAIEmbeddings(ABC):
-    """
-    Contains common logic across both OpenAI and Azure OpenAI embedding services
-    Can split source text into batches for more efficient embedding calls
-    """
+    """Client wrapper that handles batching, retries, and token accounting."""

-    SUPPORTED_BATCH_AOAI_MODEL = {
+    SUPPORTED_BATCH_MODEL = {
         "text-embedding-ada-002": {"token_limit": 8100, "max_batch_size": 16},
         "text-embedding-3-small": {"token_limit": 8100, "max_batch_size": 16},
         "text-embedding-3-large": {"token_limit": 8100, "max_batch_size": 16},
@@ -52,13 +43,26 @@ class OpenAIEmbeddings(ABC):
         "text-embedding-3-large": True,
     }

-    def __init__(self, open_ai_model_name: str, open_ai_dimensions: int, disable_batch: bool = False):
+    def __init__(
+        self,
+        open_ai_client: AsyncOpenAI,
+        open_ai_model_name: str,
+        open_ai_dimensions: int,
+        *,
+        disable_batch: bool = False,
+        azure_deployment_name: str | None = None,
+        azure_endpoint: str | None = None,
+    ):
+        self.open_ai_client = open_ai_client
         self.open_ai_model_name = open_ai_model_name
         self.open_ai_dimensions = open_ai_dimensions
         self.disable_batch = disable_batch
+        self.azure_deployment_name = azure_deployment_name
+        self.azure_endpoint = azure_endpoint.rstrip("/") if azure_endpoint else None

-    async def create_client(self) -> AsyncOpenAI:
-        raise NotImplementedError
+    @property
+    def _api_model(self) -> str:
+        return self.azure_deployment_name or self.open_ai_model_name

     def before_retry_sleep(self, retry_state):
         logger.info("Rate limited on the OpenAI embeddings API, sleeping before retrying...")
@@ -68,7 +72,7 @@ def calculate_token_length(self, text: str):
         return len(encoding.encode(text))

     def split_text_into_batches(self, texts: list[str]) -> list[EmbeddingBatch]:
-        batch_info = OpenAIEmbeddings.SUPPORTED_BATCH_AOAI_MODEL.get(self.open_ai_model_name)
+        batch_info = OpenAIEmbeddings.SUPPORTED_BATCH_MODEL.get(self.open_ai_model_name)
         if not batch_info:
             raise NotImplementedError(
                 f"Model {self.open_ai_model_name} is not supported with batch embedding operations"
@@ -101,7 +105,6 @@ def split_text_into_batches(self, texts: list[str]) -> list[EmbeddingBatch]:
     async def create_embedding_batch(self, texts: list[str], dimensions_args: ExtraArgs) -> list[list[float]]:
         batches = self.split_text_into_batches(texts)
         embeddings = []
-        client = await self.create_client()
         for batch in batches:
             async for attempt in AsyncRetrying(
                 retry=retry_if_exception_type(RateLimitError),
@@ -110,8 +113,8 @@ async def create_embedding_batch(self, texts: list[str], dimensions_args: ExtraArgs) -> list[list[float]]:
                 before_sleep=self.before_retry_sleep,
             ):
                 with attempt:
-                    emb_response = await client.embeddings.create(
-                        model=self.open_ai_model_name, input=batch.texts, **dimensions_args
+                    emb_response = await self.open_ai_client.embeddings.create(
+                        model=self._api_model, input=batch.texts, **dimensions_args
                     )
                     embeddings.extend([data.embedding for data in emb_response.data])
                     logger.info(
@@ -123,16 +126,15 @@ async def create_embedding_batch(self, texts: list[str], dimensions_args: ExtraArgs) -> list[list[float]]:
         return embeddings

     async def create_embedding_single(self, text: str, dimensions_args: ExtraArgs) -> list[float]:
-        client = await self.create_client()
         async for attempt in AsyncRetrying(
             retry=retry_if_exception_type(RateLimitError),
             wait=wait_random_exponential(min=15, max=60),
             stop=stop_after_attempt(15),
             before_sleep=self.before_retry_sleep,
         ):
             with attempt:
-                emb_response = await client.embeddings.create(
-                    model=self.open_ai_model_name, input=text, **dimensions_args
+                emb_response = await self.open_ai_client.embeddings.create(
+                    model=self._api_model, input=text, **dimensions_args
                 )
                 logger.info("Computed embedding for text section. Character count: %d", len(text))

@@ -146,85 +148,12 @@ async def create_embeddings(self, texts: list[str]) -> list[list[float]]:
             else {}
         )

-        if not self.disable_batch and self.open_ai_model_name in OpenAIEmbeddings.SUPPORTED_BATCH_AOAI_MODEL:
+        if not self.disable_batch and self.open_ai_model_name in OpenAIEmbeddings.SUPPORTED_BATCH_MODEL:
             return await self.create_embedding_batch(texts, dimensions_args)

         return [await self.create_embedding_single(text, dimensions_args) for text in texts]


-class AzureOpenAIEmbeddingService(OpenAIEmbeddings):
-    """
-    Class for using Azure OpenAI embeddings
-    To learn more please visit https://learn.microsoft.com/azure/ai-services/openai/concepts/understand-embeddings
-    """
-
-    def __init__(
-        self,
-        open_ai_service: Optional[str],
-        open_ai_deployment: Optional[str],
-        open_ai_model_name: str,
-        open_ai_dimensions: int,
-        credential: AsyncTokenCredential | AzureKeyCredential,
-        open_ai_custom_url: Optional[str] = None,
-        disable_batch: bool = False,
-    ):
-        super().__init__(open_ai_deployment or open_ai_model_name, open_ai_dimensions, disable_batch)
-        self.open_ai_service = open_ai_service
-        if open_ai_service:
-            self.open_ai_endpoint = f"https://{open_ai_service}.openai.azure.com"
-        elif open_ai_custom_url:
-            self.open_ai_endpoint = open_ai_custom_url
-        else:
-            raise ValueError("Either open_ai_service or open_ai_custom_url must be provided")
-        self.open_ai_deployment = open_ai_deployment
-        self.credential = credential
-
-    async def create_client(self) -> AsyncOpenAI:
-        class AuthArgs(TypedDict, total=False):
-            api_key: str
-
-        auth_args = AuthArgs()
-        if isinstance(self.credential, AzureKeyCredential):
-            auth_args["api_key"] = self.credential.key
-        elif isinstance(self.credential, AsyncTokenCredential):
-            token_provider = get_bearer_token_provider(
-                self.credential, "https://cognitiveservices.azure.com/.default"
-            )
-            auth_args["api_key"] = token_provider
-        else:
-            raise TypeError("Invalid credential type")
-
-        # For Azure OpenAI, we need to use the v1 endpoint
-        base_url = f"{self.open_ai_endpoint}/openai/v1"
-
-        return AsyncOpenAI(
-            base_url=base_url,
-            **auth_args,
-        )
-
-
-class OpenAIEmbeddingService(OpenAIEmbeddings):
-    """
-    Class for using OpenAI embeddings
-    To learn more please visit https://platform.openai.com/docs/guides/embeddings
-    """
-
-    def __init__(
-        self,
-        open_ai_model_name: str,
-        open_ai_dimensions: int,
-        credential: str,
-        organization: Optional[str] = None,
-        disable_batch: bool = False,
-    ):
-        super().__init__(open_ai_model_name, open_ai_dimensions, disable_batch)
-        self.credential = credential
-        self.organization = organization
-
-    async def create_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(api_key=self.credential, organization=self.organization)
-
-
 class ImageEmbeddings:
     """
     Class for using image embeddings from Azure AI Vision
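With AzureOpenAIEmbeddingService and OpenAIEmbeddingService removed, callers construct OpenAIEmbeddings directly with whichever AsyncOpenAI client they already hold; the _api_model property picks the Azure deployment name when one is set and the plain model name otherwise. A minimal usage sketch with illustrative client configuration and texts:

import asyncio

from openai import AsyncOpenAI

from prepdocslib.embeddings import OpenAIEmbeddings


async def embed_chunks() -> None:
    # Any AsyncOpenAI client works here, including one pointed at an Azure /openai/v1 endpoint.
    client = AsyncOpenAI(api_key="sk-...")  # placeholder key
    embeddings = OpenAIEmbeddings(
        open_ai_client=client,
        open_ai_model_name="text-embedding-3-small",
        open_ai_dimensions=1536,
        disable_batch=False,
        # azure_deployment_name / azure_endpoint stay None for non-Azure hosts
    )
    vectors = await embeddings.create_embeddings(["first chunk of text", "second chunk of text"])
    print(len(vectors), "embeddings of dimension", len(vectors[0]))


asyncio.run(embed_chunks())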
