Skip to content

Commit 7c8f825

Browse files
committed
Fix media description with OpenAI
1 parent 001c86f commit 7c8f825

File tree

8 files changed

+134
-49
lines changed

8 files changed

+134
-49
lines changed

app/backend/app.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
CONFIG_USER_BLOB_CONTAINER_CLIENT,
8585
CONFIG_USER_UPLOAD_ENABLED,
8686
CONFIG_VECTOR_SEARCH_ENABLED,
87+
CONFIG_MULTIMODAL_ENABLED
8788
)
8889
from core.authentication import AuthenticationHelper
8990
from core.sessionhelper import create_session_id
@@ -705,6 +706,8 @@ async def setup_clients():
705706
query_speller=AZURE_SEARCH_QUERY_SPELLER,
706707
prompt_manager=prompt_manager,
707708
reasoning_effort=OPENAI_REASONING_EFFORT,
709+
vision_endpoint=AZURE_VISION_ENDPOINT,
710+
vision_token_provider=token_provider,
708711
)
709712

710713
@bp.after_app_serving
@@ -734,7 +737,7 @@ def create_app():
734737

735738
# Log levels should be one of https://docs.python.org/3/library/logging.html#logging-levels
736739
# Set root level to WARNING to avoid seeing overly verbose logs from SDKS
737-
logging.basicConfig(level=logging.WARNING)
740+
logging.basicConfig(level=logging.DEBUG)
738741
# Set our own logger levels to INFO by default
739742
app_level = os.getenv("APP_LOG_LEVEL", "INFO")
740743
app.logger.setLevel(os.getenv("APP_LOG_LEVEL", app_level))

app/backend/approaches/chatreadretrieveread.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Awaitable
2-
from typing import Any, Optional, Union, cast
2+
from typing import Any, Optional, Union, cast, Callable
33

44
from azure.search.documents.agent.aio import KnowledgeAgentRetrievalClient
55
from azure.search.documents.aio import SearchClient
@@ -47,6 +47,8 @@ def __init__(
4747
query_speller: str,
4848
prompt_manager: PromptManager,
4949
reasoning_effort: Optional[str] = None,
50+
vision_endpoint: Optional[str] = None,
51+
vision_token_provider: Callable[[], Awaitable[str]],
5052
):
5153
self.search_client = search_client
5254
self.search_index_name = search_index_name
@@ -71,6 +73,8 @@ def __init__(
7173
self.answer_prompt = self.prompt_manager.load_prompt("chat_answer_question.prompty")
7274
self.reasoning_effort = reasoning_effort
7375
self.include_token_usage = True
76+
self.vision_endpoint = vision_endpoint
77+
self.vision_token_provider = vision_token_provider
7478

7579
async def run_until_final_call(
7680
self,

app/backend/prepdocs.py

Lines changed: 107 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from prepdocslib.strategy import DocumentAction, SearchInfo, Strategy
3636
from prepdocslib.textparser import TextParser
3737
from prepdocslib.textsplitter import SentenceTextSplitter, SimpleTextSplitter
38+
from enum import Enum
3839

3940
logger = logging.getLogger("scripts")
4041

@@ -126,15 +127,23 @@ def setup_list_file_strategy(
126127
return list_file_strategy
127128

128129

130+
class OpenAIHost(str, Enum):
131+
OPENAI = "openai"
132+
AZURE = "azure"
133+
AZURE_CUSTOM = "azure_custom"
134+
LOCAL = "local"
135+
136+
129137
def setup_embeddings_service(
130138
azure_credential: AsyncTokenCredential,
131-
openai_host: str,
132-
openai_model_name: str,
133-
openai_service: Union[str, None],
134-
openai_custom_url: Union[str, None],
135-
openai_deployment: Union[str, None],
136-
openai_dimensions: int,
137-
openai_api_version: str,
139+
openai_host: OpenAIHost,
140+
emb_model_name: str,
141+
emb_model_dimensions: int,
142+
azure_openai_service: Union[str, None],
143+
azure_openai_custom_url: Union[str, None],
144+
azure_openai_deployment: Union[str, None],
145+
azure_openai_key: Union[str, None],
146+
azure_openai_api_version: str,
138147
openai_key: Union[str, None],
139148
openai_org: Union[str, None],
140149
disable_vectors: bool = False,
@@ -144,31 +153,83 @@ def setup_embeddings_service(
144153
logger.info("Not setting up embeddings service")
145154
return None
146155

147-
if openai_host != "openai":
156+
if openai_host in [OpenAIHost.AZURE, OpenAIHost.AZURE_CUSTOM]:
148157
azure_open_ai_credential: Union[AsyncTokenCredential, AzureKeyCredential] = (
149-
azure_credential if openai_key is None else AzureKeyCredential(openai_key)
158+
azure_credential if azure_openai_key is None else AzureKeyCredential(azure_openai_key)
150159
)
151160
return AzureOpenAIEmbeddingService(
152-
open_ai_service=openai_service,
153-
open_ai_custom_url=openai_custom_url,
154-
open_ai_deployment=openai_deployment,
155-
open_ai_model_name=openai_model_name,
156-
open_ai_dimensions=openai_dimensions,
157-
open_ai_api_version=openai_api_version,
161+
open_ai_service=azure_openai_service,
162+
open_ai_custom_url=azure_openai_custom_url,
163+
open_ai_deployment=azure_openai_deployment,
164+
open_ai_model_name=emb_model_name,
165+
open_ai_dimensions=emb_model_dimensions,
166+
open_ai_api_version=azure_openai_api_version,
158167
credential=azure_open_ai_credential,
159168
disable_batch=disable_batch_vectors,
160169
)
161170
else:
162171
if openai_key is None:
163172
raise ValueError("OpenAI key is required when using the non-Azure OpenAI API")
164173
return OpenAIEmbeddingService(
165-
open_ai_model_name=openai_model_name,
166-
open_ai_dimensions=openai_dimensions,
174+
open_ai_model_name=emb_model_name,
175+
open_ai_dimensions=emb_model_dimensions,
167176
credential=openai_key,
168177
organization=openai_org,
169178
disable_batch=disable_batch_vectors,
170179
)
171180

181+
def setup_openai_client(
182+
openai_host: OpenAIHost,
183+
azure_openai_api_key: Union[str, None] = None,
184+
azure_openai_api_version: Union[str, None] = None,
185+
azure_openai_service: Union[str, None] = None,
186+
azure_openai_custom_url: Union[str, None] = None,
187+
azure_credential: AsyncTokenCredential = None,
188+
openai_api_key: Union[str, None] = None,
189+
openai_organization: Union[str, None] = None,
190+
):
191+
if openai_host not in OpenAIHost:
192+
raise ValueError(f"Invalid OPENAI_HOST value: {openai_host}. Must be one of {[h.value for h in OpenAIHost]}.")
193+
194+
if openai_host in [OpenAIHost.AZURE, OpenAIHost.AZURE_CUSTOM]:
195+
if openai_host == OpenAIHost.AZURE_CUSTOM:
196+
logger.info("OPENAI_HOST is azure_custom, setting up Azure OpenAI custom client")
197+
if not azure_openai_custom_url:
198+
raise ValueError("AZURE_OPENAI_CUSTOM_URL must be set when OPENAI_HOST is azure_custom")
199+
endpoint = azure_openai_custom_url
200+
else:
201+
logger.info("OPENAI_HOST is azure, setting up Azure OpenAI client")
202+
if not azure_openai_service:
203+
raise ValueError("AZURE_OPENAI_SERVICE must be set when OPENAI_HOST is azure")
204+
endpoint = f"https://{azure_openai_service}.openai.azure.com"
205+
if azure_openai_api_key:
206+
logger.info("AZURE_OPENAI_API_KEY_OVERRIDE found, using as api_key for Azure OpenAI client")
207+
openai_client = AsyncAzureOpenAI(
208+
api_version=azure_openai_api_version, azure_endpoint=endpoint, api_key=azure_openai_api_key
209+
)
210+
else:
211+
logger.info("Using Azure credential (passwordless authentication) for Azure OpenAI client")
212+
token_provider = get_bearer_token_provider(azure_credential, "https://cognitiveservices.azure.com/.default")
213+
openai_client = AsyncAzureOpenAI(
214+
api_version=azure_openai_api_version,
215+
azure_endpoint=endpoint,
216+
azure_ad_token_provider=token_provider,
217+
)
218+
elif openai_host == OpenAIHost.LOCAL:
219+
logger.info("OPENAI_HOST is local, setting up local OpenAI client for OPENAI_BASE_URL with no key")
220+
openai_client = AsyncOpenAI(
221+
base_url=os.environ["OPENAI_BASE_URL"],
222+
api_key="no-key-required",
223+
)
224+
else:
225+
logger.info(
226+
"OPENAI_HOST is not azure, setting up OpenAI client using OPENAI_API_KEY and OPENAI_ORGANIZATION environment variables"
227+
)
228+
openai_client = AsyncOpenAI(
229+
api_key=openai_api_key,
230+
organization=openai_organization,
231+
)
232+
return openai_client
172233

173234
def setup_file_processors(
174235
azure_credential: AsyncTokenCredential,
@@ -194,7 +255,7 @@ def setup_file_processors(
194255
doc_int_parser = DocumentAnalysisParser(
195256
endpoint=f"https://{document_intelligence_service}.cognitiveservices.azure.com/",
196257
credential=documentintelligence_creds,
197-
media_description_strategy = "openai" if use_multimodal else "contentunderstanding" if use_content_understanding else "none",
258+
media_description_strategy = MediaDescriptionStrategy.OPENAI if use_multimodal else MediaDescriptionStrategy.CONTENTUNDERSTANDING if use_content_understanding else MediaDescriptionStrategy.NONE,
198259
openai_client=openai_client,
199260
openai_model=openai_model,
200261
openai_deployment=openai_deployment,
@@ -323,7 +384,7 @@ async def main(strategy: Strategy, setup_index: bool = True):
323384
args = parser.parse_args()
324385

325386
if args.verbose:
326-
logging.basicConfig(format="%(message)s", datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)])
387+
logging.basicConfig(format="%(message)s", datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)], level=logging.WARNING)
327388
# We only set the level to INFO for our logger,
328389
# to avoid seeing the noisy INFO level logs from the Azure SDKs
329390
logger.setLevel(logging.DEBUG)
@@ -397,31 +458,38 @@ async def main(strategy: Strategy, setup_index: bool = True):
397458
datalake_key=clean_key_if_exists(args.datalakekey),
398459
)
399460

400-
openai_host = os.environ["OPENAI_HOST"]
401-
openai_key = None
402-
if os.getenv("AZURE_OPENAI_API_KEY_OVERRIDE"):
403-
openai_key = os.getenv("AZURE_OPENAI_API_KEY_OVERRIDE")
404-
elif not openai_host.startswith("azure") and os.getenv("OPENAI_API_KEY"):
405-
openai_key = os.getenv("OPENAI_API_KEY")
406-
407-
openai_dimensions = 1536
461+
openai_host = OpenAIHost(os.environ["OPENAI_HOST"])
462+
# https://learn.microsoft.com/azure/ai-services/openai/api-version-deprecation#latest-ga-api-release
463+
azure_openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION") or "2024-06-01"
464+
emb_model_dimensions = 1536
408465
if os.getenv("AZURE_OPENAI_EMB_DIMENSIONS"):
409-
openai_dimensions = int(os.environ["AZURE_OPENAI_EMB_DIMENSIONS"])
466+
emb_model_dimensions = int(os.environ["AZURE_OPENAI_EMB_DIMENSIONS"])
410467
openai_embeddings_service = setup_embeddings_service(
411468
azure_credential=azd_credential,
412469
openai_host=openai_host,
413-
openai_model_name=os.environ["AZURE_OPENAI_EMB_MODEL_NAME"],
414-
openai_service=os.getenv("AZURE_OPENAI_SERVICE"),
415-
openai_custom_url=os.getenv("AZURE_OPENAI_CUSTOM_URL"),
416-
openai_deployment=os.getenv("AZURE_OPENAI_EMB_DEPLOYMENT"),
417-
# https://learn.microsoft.com/azure/ai-services/openai/api-version-deprecation#latest-ga-api-release
418-
openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION") or "2024-06-01",
419-
openai_dimensions=openai_dimensions,
420-
openai_key=clean_key_if_exists(openai_key),
470+
emb_model_name=os.environ["AZURE_OPENAI_EMB_MODEL_NAME"],
471+
emb_model_dimensions=emb_model_dimensions,
472+
azure_openai_service=os.getenv("AZURE_OPENAI_SERVICE"),
473+
azure_openai_custom_url=os.getenv("AZURE_OPENAI_CUSTOM_URL"),
474+
azure_openai_deployment=os.getenv("AZURE_OPENAI_EMB_DEPLOYMENT"),
475+
azure_openai_api_version=azure_openai_api_version,
476+
azure_openai_key=os.getenv("AZURE_OPENAI_API_KEY_OVERRIDE"),
477+
openai_key=clean_key_if_exists(os.getenv("OPENAI_API_KEY")),
421478
openai_org=os.getenv("OPENAI_ORGANIZATION"),
422479
disable_vectors=dont_use_vectors,
423480
disable_batch_vectors=args.disablebatchvectors,
424481
)
482+
openai_client = setup_openai_client(
483+
openai_host=openai_host,
484+
azure_openai_api_version=azure_openai_api_version,
485+
azure_openai_service=os.getenv("AZURE_OPENAI_SERVICE"),
486+
azure_openai_custom_url=os.getenv("AZURE_OPENAI_CUSTOM_URL"),
487+
azure_openai_api_key=os.getenv("AZURE_OPENAI_API_KEY_OVERRIDE"),
488+
azure_credential=azd_credential,
489+
openai_api_key=clean_key_if_exists(os.getenv("OPENAI_API_KEY")),
490+
openai_organization=os.getenv("OPENAI_ORGANIZATION"),
491+
)
492+
425493

426494
ingestion_strategy: Strategy
427495
if use_int_vectorization:
@@ -452,6 +520,9 @@ async def main(strategy: Strategy, setup_index: bool = True):
452520
use_content_understanding=use_content_understanding,
453521
use_multimodal=use_multimodal,
454522
content_understanding_endpoint=os.getenv("AZURE_CONTENTUNDERSTANDING_ENDPOINT"),
523+
openai_client=openai_client,
524+
openai_model=os.getenv("AZURE_OPENAI_CHATGPT_MODEL"),
525+
openai_deployment=os.getenv("AZURE_OPENAI_CHATGPT_DEPLOYMENT") if openai_host == OpenAIHost.AZURE else None,
455526
)
456527

457528
image_embeddings_service = setup_image_embeddings_service(

app/backend/prepdocslib/blobmanager.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,11 @@ async def upload_blob(self, file: File) -> Optional[list[str]]:
6060
blob_client = await container_client.upload_blob(blob_name, reopened_file, overwrite=True)
6161
file.url = blob_client.url
6262

63-
if self.store_page_images:
64-
if os.path.splitext(file.content.name)[1].lower() == ".pdf":
65-
return await self.upload_pdf_blob_images(service_client, container_client, file)
66-
else:
67-
logger.info("File %s is not a PDF, skipping image upload", file.content.name)
63+
#if self.store_page_images:
64+
# if os.path.splitext(file.content.name)[1].lower() == ".pdf":
65+
# return await self.upload_pdf_blob_images(service_client, container_client, file)
66+
# else:
67+
# logger.info("File %s is not a PDF, skipping image upload", file.content.name)
6868

6969
return None
7070

app/backend/prepdocslib/mediadescriber.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,17 +121,20 @@ async def describe_image(self, image_bytes: bytes) -> str:
121121

122122
response = await self.openai_client.chat.completions.create(
123123
model=self.model if self.deployment is None else self.deployment,
124+
max_tokens=500,
124125
messages=[
125126
{
126127
"role": "system",
127-
"content": "You are a helpful assistant that describes images.",
128+
"content": "You are a helpful assistant that describes images from organizational documents.",
128129
},
129130
{
130131
"role": "user",
131132
"content":
132-
[{"text": "Describe this image in detail", "type": "text"},
133-
{"image_url": {"url": image_datauri}, "type": "image_url"}]
133+
[{"text": "Describe image with no more than 5 sentences. Do not speculate about anything you don't know.", "type": "text"},
134+
{"image_url": {"url": image_datauri}, "type": "image_url", "detail": "low"}]
134135
}
135136
])
136-
return response.choices[0].message.content.strip() if response.choices else ""
137+
description = response.choices[0].message.content.strip() if response.choices else ""
138+
print(description)
139+
return description
137140

app/backend/prepdocslib/pdfparser.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def __init__(
6161
endpoint: str,
6262
credential: Union[AsyncTokenCredential, AzureKeyCredential],
6363
model_id="prebuilt-layout",
64-
include_media_description: bool = False,
6564
media_description_strategy: Enum = MediaDescriptionStrategy.NONE,
6665
# If using OpenAI, this is the client to use
6766
openai_client: Union[AsyncOpenAI, None] = None,
@@ -275,6 +274,10 @@ def crop_image_from_pdf_page(
275274
pix = page.get_pixmap(matrix=pymupdf.Matrix(page_dpi / bbox_dpi, page_dpi / bbox_dpi), clip=rect)
276275

277276
img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
277+
# print out the number of pixels
278+
print(f"Cropped image size: {img.size} pixels")
278279
bytes_io = io.BytesIO()
279280
img.save(bytes_io, format="PNG")
281+
with open(f"cropped_page_{page_number + 1}.png", "wb") as f:
282+
f.write(bytes_io.getvalue())
280283
return bytes_io.getvalue()

app/backend/prepdocslib/searchmanager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ async def create_index(self):
298298
field.name == self.field_name_embedding for field in existing_index.fields
299299
):
300300
logger.info("Adding %s field for text embeddings", self.field_name_embedding)
301+
embedding_field.stored = True
301302
existing_index.fields.append(embedding_field)
302303
if existing_index.vector_search is None:
303304
raise ValueError("Vector search is not enabled for the existing index")

app/frontend/package-lock.json

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)