Skip to content

Commit bab4350

Browse files
committed
Mypy fixes
1 parent c803bfa commit bab4350

File tree

4 files changed

+22
-14
lines changed

4 files changed

+22
-14
lines changed

app/backend/app.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ async def setup_clients():
397397
AZURE_SEARCH_INDEX = os.environ["AZURE_SEARCH_INDEX"]
398398
AZURE_SEARCH_AGENT = os.getenv("AZURE_SEARCH_AGENT", "")
399399
# Shared by all OpenAI deployments
400-
OPENAI_HOST = os.getenv("OPENAI_HOST", "azure")
400+
OPENAI_HOST = OpenAIHost(os.getenv("OPENAI_HOST", "azure"))
401401
OPENAI_CHATGPT_MODEL = os.environ["AZURE_OPENAI_CHATGPT_MODEL"]
402402
AZURE_OPENAI_SEARCHAGENT_MODEL = os.getenv("AZURE_OPENAI_SEARCHAGENT_MODEL")
403403
AZURE_OPENAI_SEARCHAGENT_DEPLOYMENT = os.getenv("AZURE_OPENAI_SEARCHAGENT_DEPLOYMENT")
@@ -407,9 +407,13 @@ async def setup_clients():
407407
# Used with Azure OpenAI deployments
408408
AZURE_OPENAI_SERVICE = os.getenv("AZURE_OPENAI_SERVICE")
409409
AZURE_OPENAI_CHATGPT_DEPLOYMENT = (
410-
os.getenv("AZURE_OPENAI_CHATGPT_DEPLOYMENT") if OPENAI_HOST.startswith("azure") else None
410+
os.getenv("AZURE_OPENAI_CHATGPT_DEPLOYMENT")
411+
if OPENAI_HOST in [OpenAIHost.AZURE, OpenAIHost.AZURE_CUSTOM]
412+
else None
413+
)
414+
AZURE_OPENAI_EMB_DEPLOYMENT = (
415+
os.getenv("AZURE_OPENAI_EMB_DEPLOYMENT") if OPENAI_HOST in [OpenAIHost.AZURE, OpenAIHost.AZURE_CUSTOM] else None
411416
)
412-
AZURE_OPENAI_EMB_DEPLOYMENT = os.getenv("AZURE_OPENAI_EMB_DEPLOYMENT") if OPENAI_HOST.startswith("azure") else None
413417
AZURE_OPENAI_CUSTOM_URL = os.getenv("AZURE_OPENAI_CUSTOM_URL")
414418
# https://learn.microsoft.com/azure/ai-services/openai/api-version-deprecation#latest-ga-api-release
415419
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION") or "2024-10-21"

app/backend/approaches/approach.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,8 @@ class ExtraArgs(TypedDict, total=False):
449449
return VectorizedQuery(vector=query_vector, k_nearest_neighbors=50, fields=self.embedding_field)
450450

451451
async def compute_multimodal_embedding(self, q: str):
452+
if not self.image_embeddings_client:
453+
raise ValueError("Approach is missing an image embeddings client for multimodal queries")
452454
multimodal_query_vector = await self.image_embeddings_client.create_embedding_for_text(q)
453455
return VectorizedQuery(vector=multimodal_query_vector, k_nearest_neighbors=50, fields="images/embedding")
454456

app/backend/prepdocslib/blobmanager.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
from azure.storage.blob.aio import (
1313
StorageStreamDownloader as BlobStorageStreamDownloader,
1414
)
15-
from azure.storage.filedatalake import DataLakeDirectoryClient
16-
from azure.storage.filedatalake import (
15+
from azure.storage.filedatalake.aio import DataLakeDirectoryClient, FileSystemClient
16+
from azure.storage.filedatalake.aio import (
1717
StorageStreamDownloader as AdlsBlobStorageStreamDownloader,
1818
)
19-
from azure.storage.filedatalake.aio import FileSystemClient
2019
from PIL import Image, ImageDraw, ImageFont
2120

2221
from .listfilestrategy import File
@@ -254,16 +253,17 @@ async def upload_document_image(
254253

255254
async def download_blob(
256255
self, blob_path: str, user_oid: Optional[str] = None, as_bytes: bool = False
257-
) -> Optional[Union[AdlsBlobStorageStreamDownloader, BlobStorageStreamDownloader]]:
256+
) -> Optional[Union[AdlsBlobStorageStreamDownloader, bytes]]:
258257
"""
259258
Downloads a blob from Azure Data Lake Storage.
260259
261260
Args:
262261
blob_path: The path to the blob in the format {user_oid}/{document_name}/images/{image_name}
263262
user_oid: The user's object ID
263+
as_bytes: If True, returns the blob as bytes, otherwise returns a stream downloader
264264
265265
Returns:
266-
Optional[Union[AdlsBlobStorageStreamDownloader, BlobStorageStreamDownloader]]: A stream downloader for the blob, or None if not found
266+
Optional[Union[AdlsBlobStorageStreamDownloader, bytes]]: A stream downloader for the blob, or bytes if as_bytes=True, or None if not found
267267
"""
268268
if user_oid is None:
269269
logger.warning("user_oid must be provided for Data Lake Storage operations.")
@@ -322,8 +322,8 @@ async def remove_blob(self, filename: str, user_oid: str) -> None:
322322
await file_client.delete_file()
323323

324324
# Try to delete any associated image directories
325+
image_directory_path = self._get_image_directory_path(filename, user_oid)
325326
try:
326-
image_directory_path = self._get_image_directory_path(filename, user_oid)
327327
image_directory_client = await self._ensure_directory(
328328
directory_path=image_directory_path, user_oid=user_oid
329329
)
@@ -408,7 +408,7 @@ def get_managedidentity_connectionstring(self):
408408
raise ValueError("Account, resource group, and subscription ID must be set to generate connection string.")
409409
return f"ResourceId=/subscriptions/{self.subscription_id}/resourceGroups/{self.resource_group}/providers/Microsoft.Storage/storageAccounts/{self.account};"
410410

411-
async def upload_blob(self, file: File) -> Optional[list[str]]:
411+
async def upload_blob(self, file: File) -> str:
412412
container_client = self.blob_service_client.get_container_client(self.container)
413413
if not await container_client.exists():
414414
await container_client.create_container()
@@ -421,6 +421,8 @@ async def upload_blob(self, file: File) -> Optional[list[str]]:
421421
blob_client = await container_client.upload_blob(blob_name, reopened_file, overwrite=True)
422422
file.url = blob_client.url
423423

424+
return unquote(file.url)
425+
424426
async def upload_document_image(
425427
self,
426428
document_filename: str,

app/backend/prepdocslib/filestrategy.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from azure.core.credentials import AzureKeyCredential
55

6-
from .blobmanager import BlobManager
6+
from .blobmanager import AdlsBlobManager, BaseBlobManager, BlobManager
77
from .embeddings import ImageEmbeddings, OpenAIEmbeddings
88
from .fileprocessor import FileProcessor
99
from .listfilestrategy import File, ListFileStrategy
@@ -18,7 +18,7 @@ async def parse_file(
1818
file: File,
1919
file_processors: dict[str, FileProcessor],
2020
category: Optional[str] = None,
21-
blob_manager: Optional[BlobManager] = None,
21+
blob_manager: Optional[BaseBlobManager] = None,
2222
image_embeddings_client: Optional[ImageEmbeddings] = None,
2323
user_oid: Optional[str] = None,
2424
) -> list[Section]:
@@ -145,10 +145,10 @@ def __init__(
145145
self,
146146
search_info: SearchInfo,
147147
file_processors: dict[str, FileProcessor],
148+
blob_manager: AdlsBlobManager,
149+
search_field_name_embedding: Optional[str] = None,
148150
embeddings: Optional[OpenAIEmbeddings] = None,
149151
image_embeddings: Optional[ImageEmbeddings] = None,
150-
search_field_name_embedding: Optional[str] = None,
151-
blob_manager: Optional[BlobManager] = None,
152152
):
153153
self.file_processors = file_processors
154154
self.embeddings = embeddings

0 commit comments

Comments
 (0)