
Commit 2a73065

Getting image citations almost working
1 parent e85f8c5 commit 2a73065

9 files changed: +114 −102 lines

9 files changed

+114
-102
lines changed

app/backend/app.py

Lines changed: 1 addition & 0 deletions
@@ -686,6 +686,7 @@ async def setup_clients():
         agent_client=agent_client,
         openai_client=openai_client,
         auth_helper=auth_helper,
+        images_blob_container_client=image_blob_container_client,
         chatgpt_model=OPENAI_CHATGPT_MODEL,
         chatgpt_deployment=AZURE_OPENAI_CHATGPT_DEPLOYMENT,
         embedding_model=OPENAI_EMB_MODEL,
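
Note on wiring: the new images_blob_container_client keyword is passed through to each approach here, but the diff does not show how that client is constructed. A minimal sketch of one way to build it with the async Azure SDK; the environment variable names are assumptions, not taken from this commit:

import os

from azure.identity.aio import DefaultAzureCredential
from azure.storage.blob.aio import ContainerClient

# Sketch only: the storage account and container names below are hypothetical.
azure_credential = DefaultAzureCredential()
image_blob_container_client = ContainerClient(
    account_url=f"https://{os.environ['AZURE_STORAGE_ACCOUNT']}.blob.core.windows.net",
    container_name=os.environ["AZURE_IMAGESTORAGE_CONTAINER"],
    credential=azure_credential,
)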

app/backend/approaches/approach.py

Lines changed: 3 additions & 0 deletions
@@ -50,6 +50,7 @@ class Document:
     score: Optional[float] = None
     reranker_score: Optional[float] = None
     search_agent_query: Optional[str] = None
+    images: Optional[list[dict[str, Any]]] = None
 
     def serialize_for_results(self) -> dict[str, Any]:
         result_dict = {
@@ -75,6 +76,7 @@ def serialize_for_results(self) -> dict[str, Any]:
             "score": self.score,
             "reranker_score": self.reranker_score,
             "search_agent_query": self.search_agent_query,
+            "images": self.images,
         }
         return result_dict
 
@@ -238,6 +240,7 @@ async def search(
                         captions=cast(list[QueryCaptionResult], document.get("@search.captions")),
                         score=document.get("@search.score"),
                         reranker_score=document.get("@search.reranker_score"),
+                        images=document.get("images"),
                     )
                 )
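
The new images field is filled directly from the search document, so its shape is whatever the indexer wrote. The only key the rest of this commit relies on is "url" (see retrievethenread.py below, which reads img["url"]). A sketch of the expected shape; any keys beyond "url" are hypothetical:

# Sketch: what a value of Document.images is expected to look like downstream.
# Only the "url" key is read in this commit; any other keys would pass through untouched.
images = [
    {"url": "https://myaccount.blob.core.windows.net/images/Benefit_Options_pdf/page2/figure1_1.png"},
    {"url": "https://myaccount.blob.core.windows.net/images/Benefit_Options_pdf/page3/figure2_1.png"},
]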

app/backend/approaches/prompts/ask_answer_question.prompty

Lines changed: 5 additions & 4 deletions
@@ -19,10 +19,12 @@ Use 'you' to refer to the individual asking the questions even if they ask with
 Answer the following question using only the data provided in the sources below.
 Each source has a name followed by colon and the actual information, always include the source name for each fact you use in the response.
 If you cannot answer using the sources below, say you don't know. Use below example to answer.
-{% if use_images %}
-Each image source has the file name in the top left corner of the image with coordinates (10,10) pixels and is in the format SourceFileName:<file_name>.
+{% if image_sources %}
+Each image source has the original document file name in the top left corner of the image with coordinates (10,10) pixels and is in the format Document:<document_name.ext#page=N>.
+The filename of the actual image is in the top right corner of the image and is in the format Figure:<image_name.png>.
 Each text source starts in a new line and has the file name followed by colon and the actual information.
-Always include the source name from the image or text for each fact you use in the response in the format: [filename].
+Always include the source document filename for each fact you use in the response in the format: [document_name.ext#page=N].
+If you are referencing an image, add the image filename in the format: [document_name.ext#page=N(image_name.png)].
 Answer the following question using only the data provided in the sources below.
 The text and image source can be the same file name, don't use the image title when citing the image source, only use the file name as mentioned.
 If you cannot answer using the sources below, say you don't know. Return just the answer without any input texts.
@@ -50,6 +52,5 @@ user:
 {% if text_sources is defined %}
 Sources:
 {% for text_source in text_sources %}
-{{ text_source }}
 {% endfor %}
 {% endif %}
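
The template now keys the image-citation instructions off image_sources rather than a use_images flag, so the extra rules only appear when images were actually retrieved. A rough sketch of how that conditional behaves, rendered here with plain jinja2 to mimic the Jinja-style syntax these templates use; the values are made up for illustration:

from jinja2 import Template

snippet = Template(
    "{% if image_sources %}"
    "If you are referencing an image, add the image filename in the format: "
    "[document_name.ext#page=N(image_name.png)]."
    "{% endif %}"
)

# The instructions are included only when at least one image source is passed in:
print(snippet.render(image_sources=["data:image/png;base64,..."]))
# An empty or missing list renders to an empty string, keeping the prompt shorter:
print(snippet.render(image_sources=[]))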

app/backend/approaches/retrievethenread.py

Lines changed: 34 additions & 3 deletions
@@ -3,12 +3,14 @@
 from azure.search.documents.agent.aio import KnowledgeAgentRetrievalClient
 from azure.search.documents.aio import SearchClient
 from azure.search.documents.models import VectorQuery
+from azure.storage.blob.aio import ContainerClient
 from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
 
 from approaches.approach import Approach, DataPoints, ExtraInfo, ThoughtStep
 from approaches.promptmanager import PromptManager
 from core.authentication import AuthenticationHelper
+from core.imageshelper import download_blob_as_base64
 
 
 class RetrieveThenReadApproach(Approach):
@@ -27,6 +29,7 @@ def __init__(
         agent_deployment: Optional[str],
         agent_client: KnowledgeAgentRetrievalClient,
         auth_helper: AuthenticationHelper,
+        images_blob_container_client: ContainerClient,
         openai_client: AsyncOpenAI,
         chatgpt_model: str,
         chatgpt_deployment: Optional[str],  # Not needed for non-Azure OpenAI
@@ -49,6 +52,7 @@ def __init__(
         self.chatgpt_deployment = chatgpt_deployment
         self.openai_client = openai_client
         self.auth_helper = auth_helper
+        self.images_blob_container_client = images_blob_container_client
         self.chatgpt_model = chatgpt_model
         self.embedding_model = embedding_model
         self.embedding_dimensions = embedding_dimensions
@@ -86,7 +90,11 @@ async def run(
         messages = self.prompt_manager.render_prompt(
             self.answer_prompt,
             self.get_system_prompt_variables(overrides.get("prompt_template"))
-            | {"user_query": q, "text_sources": extra_info.data_points.text},
+            | {
+                "user_query": q,
+                "text_sources": extra_info.data_points.text,
+                "image_sources": extra_info.data_points.images,
+            },
         )
 
         chat_completion = cast(
@@ -126,6 +134,7 @@ async def run_search_approach(
         use_semantic_ranker = True if overrides.get("semantic_ranker") else False
         use_query_rewriting = True if overrides.get("query_rewriting") else False
         use_semantic_captions = True if overrides.get("semantic_captions") else False
+        use_multimodal = True  # TODO: if overrides.get("use_multimodal") else False
         top = overrides.get("top", 3)
         minimum_search_score = overrides.get("minimum_search_score", 0.0)
         minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
@@ -137,6 +146,11 @@ async def run_search_approach(
         if use_vector_search:
            vectors.append(await self.compute_text_embedding(q))
 
+        # If multimodal is enabled, also compute image embeddings
+        # TODO: will this work with agentic? is this doing multivector search correctly?
+        # if use_multimodal:
+        #     vectors.append(await self.compute_image_embedding(q))
+
         results = await self.search(
             top,
             q,
@@ -151,10 +165,26 @@ async def run_search_approach(
             use_query_rewriting,
         )
 
-        text_sources = self.get_sources_content(results, use_semantic_captions, use_image_citation=False)
+        text_sources = self.get_sources_content(results, use_semantic_captions, use_image_citation=use_multimodal)
+
+        # Extract unique image URLs from results if multimodal is enabled
+
+        seen_urls = set()
+        image_sources = []
+        if use_multimodal:
+            for doc in results:
+                if hasattr(doc, "images") and doc.images:
+                    for img in doc.images:
+                        # Skip if we've already processed this URL
+                        if img["url"] in seen_urls:
+                            continue
+                        seen_urls.add(img["url"])
+                        url = await download_blob_as_base64(self.images_blob_container_client, img["url"])
+                        if url:
+                            image_sources.append(url)
 
         return ExtraInfo(
-            DataPoints(text=text_sources),
+            DataPoints(text=text_sources, images=image_sources),
             thoughts=[
                 ThoughtStep(
                     "Search using user query",
@@ -167,6 +197,7 @@ async def run_search_approach(
                         "filter": filter,
                         "use_vector_search": use_vector_search,
                         "use_text_search": use_text_search,
+                        "use_multimodal": use_multimodal,
                     },
                 ),
                 ThoughtStep(
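
The URL-deduplication loop above downloads each referenced image once and hands it to the prompt as a base64 data URI. The same logic, pulled out into a standalone helper purely for readability — a sketch, not code from this commit:

from azure.storage.blob.aio import ContainerClient

from approaches.approach import Document
from core.imageshelper import download_blob_as_base64


async def collect_image_sources(results: list[Document], container_client: ContainerClient) -> list[str]:
    """Download each unique image referenced by the search results as a base64 data URI."""
    seen_urls: set[str] = set()
    image_sources: list[str] = []
    for doc in results:
        for img in doc.images or []:
            if img["url"] in seen_urls:
                continue  # this blob was already downloaded for an earlier result
            seen_urls.add(img["url"])
            data_uri = await download_blob_as_base64(container_client, img["url"])
            if data_uri:  # skip images whose blobs are missing
                image_sources.append(data_uri)
    return image_sources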

app/backend/core/imageshelper.py

Lines changed: 20 additions & 16 deletions
@@ -1,14 +1,11 @@
 import base64
 import logging
-import os
 from typing import Optional
 
 from azure.core.exceptions import ResourceNotFoundError
 from azure.storage.blob.aio import ContainerClient
 from typing_extensions import Literal, Required, TypedDict
 
-from approaches.approach import Document
-
 
 class ImageURL(TypedDict, total=False):
     url: Required[str]
@@ -18,23 +15,30 @@ class ImageURL(TypedDict, total=False):
     """Specifies the detail level of the image."""
 
 
-async def download_blob_as_base64(blob_container_client: ContainerClient, file_path: str) -> Optional[str]:
-    base_name, _ = os.path.splitext(file_path)
-    image_filename = base_name + ".png"
+async def download_blob_as_base64(blob_container_client: ContainerClient, blob_url: str) -> Optional[str]:
     try:
-        blob = await blob_container_client.get_blob_client(image_filename).download_blob()
+        # Handle full URLs
+        if blob_url.startswith("http"):
+            # Extract blob path from full URL
+            # URL format: https://{account}.blob.core.windows.net/{container}/{blob_path}
+            url_parts = blob_url.split("/")
+            # Skip the domain parts and container name to get the blob path
+            blob_path = "/".join(url_parts[4:])
+        else:
+            # Treat as a direct blob path
+            blob_path = blob_url
+
+        # Download the blob
+        blob = await blob_container_client.get_blob_client(blob_path).download_blob()
         if not blob.properties:
-            logging.warning(f"No blob exists for {image_filename}")
+            logging.warning(f"No blob exists for {blob_path}")
             return None
+
         img = base64.b64encode(await blob.readall()).decode("utf-8")
         return f"data:image/png;base64,{img}"
     except ResourceNotFoundError:
-        logging.warning(f"No blob exists for {image_filename}")
+        logging.warning(f"No blob exists for {blob_path}")
+        return None
+    except Exception as e:
+        logging.error(f"Error downloading blob {blob_url}: {str(e)}")
         return None
-
-
-async def fetch_image(blob_container_client: ContainerClient, result: Document) -> Optional[str]:
-    if result.sourcepage:
-        img = await download_blob_as_base64(blob_container_client, result.sourcepage)
-        return img
-    return None
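
The helper now accepts either a bare blob path or a full blob URL; for URLs it drops the scheme, host, and container segment and keeps the remainder as the blob path. The parsing in isolation, with a made-up URL:

blob_url = "https://myaccount.blob.core.windows.net/images/Benefit_Options_pdf/page2/figure1_1.png"
url_parts = blob_url.split("/")
# Indexes 0-3 are "https:", "", the host, and the container name; everything after is the blob path.
blob_path = "/".join(url_parts[4:])
assert blob_path == "Benefit_Options_pdf/page2/figure1_1.png"

Note that this keeps any query string attached to the last segment, so the URLs stored in the index are assumed to be plain blob URLs (the commit also removes the SAS-token generation that previously produced signed URLs).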

app/backend/prepdocslib/blobmanager.py

Lines changed: 42 additions & 73 deletions
@@ -1,20 +1,15 @@
-import datetime
 import io
 import logging
 import os
 import re
 from typing import Optional, Union
 
-import pymupdf
 from azure.core.credentials_async import AsyncTokenCredential
 from azure.storage.blob import (
-    BlobSasPermissions,
     UserDelegationKey,
-    generate_blob_sas,
 )
-from azure.storage.blob.aio import BlobServiceClient, ContainerClient
+from azure.storage.blob.aio import BlobServiceClient
 from PIL import Image, ImageDraw, ImageFont
-from pypdf import PdfReader
 
 from .listfilestrategy import File
 
@@ -64,7 +59,7 @@ async def upload_blob(self, file: File) -> Optional[list[str]]:
         return None
 
     async def upload_document_image(
-        self, document_file: File, image_bytes: bytes, image_filename: str
+        self, document_file: File, image_bytes: bytes, image_filename: str, image_page_num: int
     ) -> Optional[str]:
         if self.image_container is None:
             raise ValueError(
@@ -75,81 +70,55 @@ async def upload_document_image(
         ) as service_client, service_client.get_container_client(self.image_container) as container_client:
             if not await container_client.exists():
                 await container_client.create_container()
-            blob_name = BlobManager.blob_name_from_file_name(document_file.content.name) + "/" + image_filename
-            logger.info("Uploading blob for document image %s", blob_name)
-            blob_client = await container_client.upload_blob(blob_name, io.BytesIO(image_bytes), overwrite=True)
-            return blob_client.url
-        return None
-
-    def get_managedidentity_connectionstring(self):
-        return f"ResourceId=/subscriptions/{self.subscriptionId}/resourceGroups/{self.resourceGroup}/providers/Microsoft.Storage/storageAccounts/{self.account};"
-
-    async def upload_pdf_blob_images(
-        self, service_client: BlobServiceClient, container_client: ContainerClient, file: File
-    ) -> list[str]:
-        with open(file.content.name, "rb") as reopened_file:
-            reader = PdfReader(reopened_file)
-            page_count = len(reader.pages)
-        doc = pymupdf.open(file.content.name)
-        sas_uris = []
-        start_time = datetime.datetime.now(datetime.timezone.utc)
-        expiry_time = start_time + datetime.timedelta(days=1)
-
-        font = None
-        try:
-            font = ImageFont.truetype("arial.ttf", 20)
-        except OSError:
-            try:
-                font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf", 20)
-            except OSError:
-                logger.info("Unable to find arial.ttf or FreeMono.ttf, using default font")
-
-        for i in range(page_count):
-            blob_name = BlobManager.blob_image_name_from_file_page(file.content.name, i)
-            logger.info("Converting page %s to image and uploading -> %s", i, blob_name)
-
-            doc = pymupdf.open(file.content.name)
-            page = doc.load_page(i)
-            pix = page.get_pixmap()
-            original_img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)  # type: ignore
 
-            # Create a new image with additional space for text
-            text_height = 40  # Height of the text area
-            new_img = Image.new("RGB", (original_img.width, original_img.height + text_height), "white")
+            # Load and modify the image to add text
+            image = Image.open(io.BytesIO(image_bytes))
+            text_height = 40
+            new_img = Image.new("RGB", (image.width, image.height + text_height), "white")
+            new_img.paste(image, (0, text_height))
 
-            # Paste the original image onto the new image
-            new_img.paste(original_img, (0, text_height))
-
-            # Draw the text on the white area
+            # Add text
             draw = ImageDraw.Draw(new_img)
-            text = f"SourceFileName:{blob_name}"
+            sourcepage = BlobManager.sourcepage_from_file_page(document_file.content.name, page=image_page_num)
+            text = f"Document: {sourcepage}"
 
-            # 10 pixels from the top and left of the image
-            x = 10
-            y = 10
-            draw.text((x, y), text, font=font, fill="black")
+            font = None
+            try:
+                font = ImageFont.truetype("arial.ttf", 24)
+            except OSError:
+                try:
+                    font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf", 24)
+                except OSError:
+                    logger.info("Unable to find arial.ttf or FreeMono.ttf, using default font")
+
+            # Draw document text on left
+            draw.text((10, 10), text, font=font, fill="black")
+
+            # Draw figure text on right
+            figure_text = f"Figure: {image_filename}"
+            if font:
+                # Get the width of the text to position it on the right
+                text_width = draw.textlength(figure_text, font=font)
+                draw.text((new_img.width - text_width - 10, 10), figure_text, font=font, fill="black")
+            else:
+                # If no font available, make a best effort to position on right
+                draw.text((new_img.width - 200, 10), figure_text, font=font, fill="black")
 
+            # Convert back to bytes
             output = io.BytesIO()
-            new_img.save(output, format="PNG")
+            new_img.save(output, format=image.format or "PNG")
             output.seek(0)
 
+            blob_name = (
+                f"{self.blob_name_from_file_name(document_file.content.name)}/page{image_page_num}/{image_filename}"
+            )
+            logger.info("Uploading blob for document image %s", blob_name)
             blob_client = await container_client.upload_blob(blob_name, output, overwrite=True)
-            if not self.user_delegation_key:
-                self.user_delegation_key = await service_client.get_user_delegation_key(start_time, expiry_time)
-
-            if blob_client.account_name is not None:
-                sas_token = generate_blob_sas(
-                    account_name=blob_client.account_name,
-                    container_name=blob_client.container_name,
-                    blob_name=blob_client.blob_name,
-                    user_delegation_key=self.user_delegation_key,
-                    permission=BlobSasPermissions(read=True),
-                    expiry=expiry_time,
-                    start=start_time,
-                )
-                sas_uris.append(f"{blob_client.url}?{sas_token}")
-
-        return sas_uris
+            return blob_client.url
+        return None
+
+    def get_managedidentity_connectionstring(self):
+        return f"ResourceId=/subscriptions/{self.subscriptionId}/resourceGroups/{self.resourceGroup}/providers/Microsoft.Storage/storageAccounts/{self.account};"
 
     async def remove_blob(self, path: Optional[str] = None):
         async with BlobServiceClient(
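
upload_document_image now stamps two labels onto the image (the source document page on the left, the figure filename on the right) and stores it under a page-scoped blob name, which is what the new prompt citation format refers back to. A rough sketch of how the blob name and a citation line up, with illustrative file names and simplified stand-ins for blob_name_from_file_name and sourcepage_from_file_page (neither helper is shown in this diff):

# Sketch only: the real helpers may normalize names and page numbers differently.
document_name = "Benefit_Options.pdf"
image_page_num = 2
image_filename = "figure1_1.png"

blob_name = f"{document_name.replace('.', '_')}/page{image_page_num}/{image_filename}"
citation = f"[{document_name}#page={image_page_num}({image_filename})]"

print(blob_name)  # Benefit_Options_pdf/page2/figure1_1.png
print(citation)   # [Benefit_Options.pdf#page=2(figure1_1.png)]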

app/backend/prepdocslib/filestrategy.py

Lines changed: 4 additions & 5 deletions
@@ -31,7 +31,7 @@ async def parse_file(
    for page in pages:
        for image in page.images:
            if image.url is None:
-                image.url = await blob_manager.upload_document_image(file, image.bytes, image.filename)
+                image.url = await blob_manager.upload_document_image(file, image.bytes, image.filename, image.page_num)
            if image_embeddings_client:
                image.embedding = await image_embeddings_client.create_embedding(image.bytes)
    logger.info("Splitting '%s' into sections", file.filename())
@@ -43,9 +43,6 @@
        section.split_page.images = [
            image for page in pages if page.page_num == section.split_page.page_num for image in page.images
        ]
-        logger.info(
-            "Section for page %d has %d images", section.split_page.page_num, len(section.split_page.images)
-        )
    return sections
 
 
@@ -115,7 +112,9 @@ async def run(self):
        files = self.list_file_strategy.list()
        async for file in files:
            try:
-                sections = await parse_file(file, self.file_processors, self.category, self.blob_manager, self.image_embeddings)
+                sections = await parse_file(
+                    file, self.file_processors, self.category, self.blob_manager, self.image_embeddings
+                )
                if sections:
                    await self.search_manager.update_content(sections, url=file.url)
            finally:
