Commit 78383ec

Use ImageEmbeddings client directly

1 parent 7c37e40

8 files changed: +58, -59 lines changed

app/backend/app.py
Lines changed: 12 additions & 8 deletions

@@ -103,6 +103,7 @@
     setup_search_info,
 )
 from prepdocslib.blobmanager import AdlsBlobManager
+from prepdocslib.embeddings import ImageEmbeddings
 from prepdocslib.filestrategy import UploadUserFileStrategy
 from prepdocslib.listfilestrategy import File

@@ -624,6 +625,10 @@ async def setup_clients():
     )
     current_app.config[CONFIG_INGESTER] = ingester

+    image_embeddings_client = None
+    if USE_MULTIMODAL:
+        image_embeddings_client = ImageEmbeddings(AZURE_VISION_ENDPOINT, azure_ai_token_provider)
+
     current_app.config[CONFIG_OPENAI_CLIENT] = openai_client
     current_app.config[CONFIG_SEARCH_CLIENT] = search_client
     current_app.config[CONFIG_AGENT_CLIENT] = agent_client

@@ -659,6 +664,7 @@ async def setup_clients():

     # Set up the two default RAG approaches for /ask and /chat
     # RetrieveThenReadApproach is used by /ask for single-turn Q&A
+
     current_app.config[CONFIG_ASK_APPROACH] = RetrieveThenReadApproach(
         search_client=search_client,
         search_index_name=AZURE_SEARCH_INDEX,

@@ -667,8 +673,6 @@ async def setup_clients():
         agent_client=agent_client,
         openai_client=openai_client,
         auth_helper=auth_helper,
-        image_blob_container_client=image_blob_container_client,
-        image_datalake_client=user_blob_container_client,
         chatgpt_model=OPENAI_CHATGPT_MODEL,
         chatgpt_deployment=AZURE_OPENAI_CHATGPT_DEPLOYMENT,
         embedding_model=OPENAI_EMB_MODEL,

@@ -681,9 +685,10 @@ async def setup_clients():
         query_speller=AZURE_SEARCH_QUERY_SPELLER,
         prompt_manager=prompt_manager,
         reasoning_effort=OPENAI_REASONING_EFFORT,
-        vision_endpoint=AZURE_VISION_ENDPOINT,
-        vision_token_provider=azure_ai_token_provider,
         multimodal_enabled=USE_MULTIMODAL,
+        image_embeddings_client=image_embeddings_client,
+        image_blob_container_client=image_blob_container_client,
+        image_datalake_client=user_blob_container_client,
     )

     # ChatReadRetrieveReadApproach is used by /chat for multi-turn conversation

@@ -695,8 +700,6 @@ async def setup_clients():
         agent_client=agent_client,
         openai_client=openai_client,
         auth_helper=auth_helper,
-        image_blob_container_client=image_blob_container_client,
-        image_datalake_client=user_blob_container_client,
         chatgpt_model=OPENAI_CHATGPT_MODEL,
         chatgpt_deployment=AZURE_OPENAI_CHATGPT_DEPLOYMENT,
         embedding_model=OPENAI_EMB_MODEL,

@@ -709,9 +712,10 @@ async def setup_clients():
         query_speller=AZURE_SEARCH_QUERY_SPELLER,
         prompt_manager=prompt_manager,
         reasoning_effort=OPENAI_REASONING_EFFORT,
-        vision_endpoint=AZURE_VISION_ENDPOINT,
-        vision_token_provider=azure_ai_token_provider,
         multimodal_enabled=USE_MULTIMODAL,
+        image_embeddings_client=image_embeddings_client,
+        image_blob_container_client=image_blob_container_client,
+        image_datalake_client=user_blob_container_client,
     )

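Taken together, these hunks construct a single ImageEmbeddings client in setup_clients and hand it to both approaches, instead of giving each one a raw endpoint plus token provider. A minimal sketch of the wiring, assuming azure-identity supplies the async token provider (the app itself resolves credentials and settings elsewhere in app.py):

    import os

    from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider

    from prepdocslib.embeddings import ImageEmbeddings

    credential = DefaultAzureCredential()
    # Bearer tokens scoped to Azure AI Vision (a Cognitive Services resource).
    azure_ai_token_provider = get_bearer_token_provider(
        credential, "https://cognitiveservices.azure.com/.default"
    )

    image_embeddings_client = None
    if os.getenv("USE_MULTIMODAL", "").lower() == "true":
        image_embeddings_client = ImageEmbeddings(
            os.environ["AZURE_VISION_ENDPOINT"], azure_ai_token_provider
        )
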
app/backend/approaches/approach.py
Lines changed: 7 additions & 26 deletions

@@ -2,10 +2,8 @@
 from collections.abc import AsyncGenerator, Awaitable
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Callable, Optional, TypedDict, Union, cast
-from urllib.parse import urljoin
+from typing import Any, Optional, TypedDict, Union, cast

-import aiohttp
 from azure.search.documents.agent.aio import KnowledgeAgentRetrievalClient
 from azure.search.documents.agent.models import (
     KnowledgeAgentAzureSearchDocReference,

@@ -38,6 +36,7 @@
 from approaches.promptmanager import PromptManager
 from core.authentication import AuthenticationHelper
 from core.imageshelper import download_blob_as_base64
+from prepdocslib.embeddings import ImageEmbeddings


 class LLMInputType(str, Enum):

@@ -174,8 +173,7 @@ def __init__(
         prompt_manager: PromptManager,
         reasoning_effort: Optional[str] = None,
         multimodal_enabled: bool = False,
-        vision_endpoint: Optional[str] = None,
-        vision_token_provider: Optional[Callable[[], Awaitable[str]]] = None,
+        image_embeddings_client: Optional[ImageEmbeddings] = None,
         image_blob_container_client: Optional[ContainerClient] = None,
         image_datalake_client: Optional[FileSystemClient] = None,
     ):

@@ -193,8 +191,7 @@ def __init__(
         self.reasoning_effort = reasoning_effort
         self.include_token_usage = True
         self.multimodal_enabled = multimodal_enabled
-        self.vision_endpoint = vision_endpoint
-        self.vision_token_provider = vision_token_provider
+        self.image_embeddings_client = image_embeddings_client
         self.image_blob_container_client = image_blob_container_client
         self.image_datalake_client = image_datalake_client

@@ -462,25 +459,9 @@ class ExtraArgs(TypedDict, total=False):
         # so we do not need to explicitly pass in an oversampling parameter here
         return VectorizedQuery(vector=query_vector, k_nearest_neighbors=50, fields=self.embedding_field)

-    async def compute_image_embedding(self, q: str):
-        if not self.vision_endpoint:
-            raise ValueError("Azure AI Vision endpoint must be set to compute image embedding.")
-        endpoint = urljoin(self.vision_endpoint, "computervision/retrieval:vectorizeText")
-        headers = {"Content-Type": "application/json"}
-        params = {"api-version": "2024-02-01", "model-version": "2023-04-15"}
-        data = {"text": q}
-
-        if not self.vision_token_provider:
-            raise ValueError("Azure AI Vision token provider must be set to compute image embedding.")
-        headers["Authorization"] = "Bearer " + await self.vision_token_provider()
-
-        async with aiohttp.ClientSession() as session:
-            async with session.post(
-                url=endpoint, params=params, headers=headers, json=data, raise_for_status=True
-            ) as response:
-                json = await response.json()
-                image_query_vector = json["vector"]
-        return VectorizedQuery(vector=image_query_vector, k_nearest_neighbors=50, fields="images/embedding")
+    async def compute_multimodal_embedding(self, q: str):
+        multimodal_query_vector = await self.image_embeddings_client.create_embedding_for_text(q)
+        return VectorizedQuery(vector=multimodal_query_vector, k_nearest_neighbors=50, fields="images/embedding")

     def get_system_prompt_variables(self, override_prompt: Optional[str]) -> dict[str, str]:
         # Allows client to replace the entire prompt, or to inject into the existing prompt using >>>

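The replacement method drops the endpoint and token-provider validation, since that now lives inside ImageEmbeddings, and it assumes image_embeddings_client is set, relying on callers to gate on multimodal_enabled. A slightly more defensive variant (an editorial sketch, not part of this commit) would fail with a clear message instead of an AttributeError:

    async def compute_multimodal_embedding(self, q: str):
        # Sketch only: the committed version omits this guard and relies on
        # callers invoking it only when multimodal support is configured.
        if self.image_embeddings_client is None:
            raise ValueError("ImageEmbeddings client must be set to compute a multimodal embedding.")
        multimodal_query_vector = await self.image_embeddings_client.create_embedding_for_text(q)
        return VectorizedQuery(vector=multimodal_query_vector, k_nearest_neighbors=50, fields="images/embedding")
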
app/backend/approaches/chatreadretrieveread.py
Lines changed: 8 additions & 8 deletions

@@ -1,7 +1,7 @@
 import json
 import re
 from collections.abc import AsyncGenerator, Awaitable
-from typing import Any, Callable, Optional, Union, cast
+from typing import Any, Optional, Union, cast

 from azure.search.documents.agent.aio import KnowledgeAgentRetrievalClient
 from azure.search.documents.aio import SearchClient

@@ -26,6 +26,7 @@
 )
 from approaches.promptmanager import PromptManager
 from core.authentication import AuthenticationHelper
+from prepdocslib.embeddings import ImageEmbeddings


 class ChatReadRetrieveReadApproach(Approach):

@@ -60,8 +61,7 @@ def __init__(
         prompt_manager: PromptManager,
         reasoning_effort: Optional[str] = None,
         multimodal_enabled: bool = False,
-        vision_endpoint: Optional[str] = None,
-        vision_token_provider: Optional[Callable[[], Awaitable[str]]] = None,
+        image_embeddings_client: Optional[ImageEmbeddings] = None,
         image_blob_container_client: Optional[ContainerClient] = None,
         image_datalake_client: Optional[FileSystemClient] = None,
     ):

@@ -72,8 +72,7 @@ def __init__(
         self.agent_client = agent_client
         self.openai_client = openai_client
         self.auth_helper = auth_helper
-        self.image_blob_container_client = image_blob_container_client
-        self.image_datalake_client = image_datalake_client
+
         self.chatgpt_model = chatgpt_model
         self.chatgpt_deployment = chatgpt_deployment
         self.embedding_deployment = embedding_deployment

@@ -90,9 +89,10 @@ def __init__(
         self.answer_prompt = self.prompt_manager.load_prompt("chat_answer_question.prompty")
         self.reasoning_effort = reasoning_effort
         self.include_token_usage = True
-        self.vision_endpoint = vision_endpoint
-        self.vision_token_provider = vision_token_provider
         self.multimodal_enabled = multimodal_enabled
+        self.image_embeddings_client = image_embeddings_client
+        self.image_blob_container_client = image_blob_container_client
+        self.image_datalake_client = image_datalake_client

     def get_search_query(self, chat_completion: ChatCompletion, user_query: str):
         response_message = chat_completion.choices[0].message

@@ -340,7 +340,7 @@ async def run_search_approach(
         if use_vector_search:
             vectors.append(await self.compute_text_embedding(query_text))
         if use_image_embeddings:
-            vectors.append(await self.compute_image_embedding(query_text))
+            vectors.append(await self.compute_multimodal_embedding(query_text))

         results = await self.search(
             top,

app/backend/approaches/retrievethenread.py
Lines changed: 7 additions & 9 deletions

@@ -1,5 +1,4 @@
-from collections.abc import Awaitable
-from typing import Any, Callable, Optional, cast
+from typing import Any, Optional, cast

 from azure.search.documents.agent.aio import KnowledgeAgentRetrievalClient
 from azure.search.documents.aio import SearchClient

@@ -19,6 +18,7 @@
 )
 from approaches.promptmanager import PromptManager
 from core.authentication import AuthenticationHelper
+from prepdocslib.embeddings import ImageEmbeddings


 class RetrieveThenReadApproach(Approach):

@@ -51,8 +51,7 @@ def __init__(
         prompt_manager: PromptManager,
         reasoning_effort: Optional[str] = None,
         multimodal_enabled: bool = False,
-        vision_endpoint: Optional[str] = None,
-        vision_token_provider: Optional[Callable[[], Awaitable[str]]] = None,
+        image_embeddings_client: Optional[ImageEmbeddings] = None,
         image_blob_container_client: Optional[ContainerClient] = None,
         image_datalake_client: Optional[FileSystemClient] = None,
     ):

@@ -64,8 +63,6 @@ def __init__(
         self.chatgpt_deployment = chatgpt_deployment
         self.openai_client = openai_client
         self.auth_helper = auth_helper
-        self.image_blob_container_client = image_blob_container_client
-        self.image_datalake_client = image_datalake_client
         self.chatgpt_model = chatgpt_model
         self.embedding_model = embedding_model
         self.embedding_dimensions = embedding_dimensions

@@ -80,9 +77,10 @@ def __init__(
         self.answer_prompt = self.prompt_manager.load_prompt("ask_answer_question.prompty")
         self.reasoning_effort = reasoning_effort
         self.include_token_usage = True
-        self.vision_endpoint = vision_endpoint
-        self.vision_token_provider = vision_token_provider
         self.multimodal_enabled = multimodal_enabled
+        self.image_embeddings_client = image_embeddings_client
+        self.image_blob_container_client = image_blob_container_client
+        self.image_datalake_client = image_datalake_client

     async def run(
         self,

@@ -186,7 +184,7 @@ async def run_search_approach(
         if vector_fields_enum != VectorFieldType.IMAGE_EMBEDDING:
             vectors.append(await self.compute_text_embedding(q))
         if use_image_embeddings:
-            vectors.append(await self.compute_image_embedding(q))
+            vectors.append(await self.compute_multimodal_embedding(q))

         results = await self.search(
             top,

app/backend/prepdocslib/embeddings.py
Lines changed: 21 additions & 1 deletion

@@ -236,7 +236,7 @@ def __init__(self, endpoint: str, token_provider: Callable[[], Awaitable[str]]):
        self.token_provider = token_provider
        self.endpoint = endpoint

-    async def create_embedding(self, image_bytes: bytes) -> list[float]:
+    async def create_embedding_for_image(self, image_bytes: bytes) -> list[float]:
         endpoint = urljoin(self.endpoint, "computervision/retrieval:vectorizeImage")
         params = {"api-version": "2024-02-01", "model-version": "2023-04-15"}
         headers = {"Authorization": "Bearer " + await self.token_provider()}

@@ -254,5 +254,25 @@ async def create_embedding(self, image_bytes: bytes) -> list[float]:
                 return resp_json["vector"]
         raise ValueError("Failed to get image embedding after multiple retries.")

+    async def create_embedding_for_text(self, q: str):
+        if not self.endpoint:
+            raise ValueError("Azure AI Vision endpoint must be set to compute image embedding.")
+        endpoint = urljoin(self.endpoint, "computervision/retrieval:vectorizeText")
+        headers = {"Content-Type": "application/json"}
+        params = {"api-version": "2024-02-01", "model-version": "2023-04-15"}
+        data = {"text": q}
+
+        if not self.token_provider:
+            raise ValueError("Azure AI Vision token provider must be set to compute image embedding.")
+        headers["Authorization"] = "Bearer " + await self.token_provider()
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                url=endpoint, params=params, headers=headers, json=data, raise_for_status=True
+            ) as response:
+                json = await response.json()
+                return json["vector"]
+        raise ValueError("Failed to get image embedding after multiple retries.")
+
     def before_retry_sleep(self, retry_state):
         logger.info("Rate limited on the Vision embeddings API, sleeping before retrying...")

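With both methods on one client, callers hold a single object for image and text vectorization against the same Azure AI Vision multimodal model, so the resulting vectors live in a shared space. A minimal usage sketch (the endpoint, token provider, and image bytes are assumed to come from the caller):

    from prepdocslib.embeddings import ImageEmbeddings

    async def demo(endpoint: str, token_provider, image_bytes: bytes) -> None:
        embeddings = ImageEmbeddings(endpoint, token_provider)
        image_vector = await embeddings.create_embedding_for_image(image_bytes)
        text_vector = await embeddings.create_embedding_for_text("aerial view of a harbor")
        # Same model version on both routes, so the vectors are comparable:
        # a text query vector can be matched against stored image vectors.
        assert len(image_vector) == len(text_vector)
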
app/backend/prepdocslib/filestrategy.py
Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ async def parse_file(
                 file.filename(), image.bytes, image.filename, image.page_num, user_oid=user_oid
             )
             if image_embeddings_client:
-                image.embedding = await image_embeddings_client.create_embedding(image.bytes)
+                image.embedding = await image_embeddings_client.create_embedding_for_image(image.bytes)
     logger.info("Splitting '%s' into sections", file.filename())
     sections = [
         Section(split_page, content=file, category=category) for split_page in processor.splitter.split_pages(pages)

tests/test_prepdocs.py
Lines changed: 1 addition & 1 deletion

@@ -232,7 +232,7 @@ async def test_image_embeddings_success(mock_azurehttp_calls):

     # Call the create_embedding method with fake image bytes
     image_bytes = b"fake_image_data"
-    embedding = await image_embeddings.create_embedding(image_bytes)
+    embedding = await image_embeddings.create_embedding_for_image(image_bytes)

     # Verify the result
     assert embedding == [

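Only the renamed image path is exercised here; a companion test for the new text path could follow the same shape. A hypothetical sketch, assuming the mock_azurehttp_calls fixture also stubs the computervision/retrieval:vectorizeText route (it may not in the current test suite):

    import pytest

    from prepdocslib.embeddings import ImageEmbeddings

    @pytest.mark.asyncio
    async def test_text_embeddings_success(mock_azurehttp_calls):
        async def token_provider() -> str:
            return "fake-token"  # stand-in; real code uses an Azure AD token provider

        image_embeddings = ImageEmbeddings(endpoint="https://example.com", token_provider=token_provider)
        embedding = await image_embeddings.create_embedding_for_text("a query about harbors")
        assert isinstance(embedding, list)  # vector of floats from the mocked response
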
todo.txt
Lines changed: 1 addition & 5 deletions

@@ -3,10 +3,9 @@ TODO:
 * Test with integrated vectorization
 * Multivector is working
 * Can we get images mapped??
+  * We need DocIntelligence skill
 * Update all TODOs in the code/docs
 * Fix/add unit tests - check coverage
-* In conftest, should I make a new env for vision? Currently I mashed it into the existing env, but it might be cleaner to have a separate one, as now I have to pass llm_inputs explicitly in the tests to turn off image responses.
-  * vote: make a new env
 * LLMInputType and VectorFields have inconsistently named values
 * # Vector fields:
 # [X] text embedding field (embedding3) use_text_vector=True

@@ -16,9 +15,6 @@ TODO:
 #  [X] text sources , use_text_sources = True
 #  [X] image sources , use_image_sources = True

-* Should we make an Azure AI Vision client class? So we dont have to pass two things around, just one?
-  * vision_endpoint and vision_token_provider
-  * we have one! its in embeddings.py, add compute_image_embeddings method to it and use it instead

 To decide:
 * For user data lake client, how often should we double check the ACL matches the oid, versus assuming the URLs convey that? (Like when fetching the image?)
