Skip to content

Commit 0681755

Browse files
committed
Make mypy happy
1 parent 9973a77 commit 0681755

File tree

3 files changed

+60
-49
lines changed

3 files changed

+60
-49
lines changed

app/backend/prepdocslib/filestrategy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from typing import List, Optional
33

44
from .blobmanager import BlobManager
5-
from .cu_image import ContentUnderstandingManager
65
from .embeddings import ImageEmbeddings, OpenAIEmbeddings
76
from .fileprocessor import FileProcessor
87
from .listfilestrategy import File, ListFileStrategy
8+
from .mediadescriber import ContentUnderstandingDescriber
99
from .searchmanager import SearchManager, Section
1010
from .strategy import DocumentAction, SearchInfo, Strategy
1111

@@ -79,7 +79,7 @@ async def setup(self):
7979
await search_manager.create_index()
8080

8181
if self.use_content_understanding:
82-
cu_manager = ContentUnderstandingManager(self.content_understanding_endpoint, self.search_info.credential)
82+
cu_manager = ContentUnderstandingDescriber(self.content_understanding_endpoint, self.search_info.credential)
8383
await cu_manager.create_analyzer()
8484

8585
async def run(self):

app/backend/prepdocslib/cu_image.py renamed to app/backend/prepdocslib/mediadescriber.py

Lines changed: 36 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Union
2+
from abc import ABC
33

44
import aiohttp
55
from azure.core.credentials_async import AsyncTokenCredential
@@ -9,39 +9,36 @@
99

1010
logger = logging.getLogger("scripts")
1111

12-
CU_API_VERSION = "2024-12-01-preview"
13-
14-
PATH_ANALYZER_MANAGEMENT = "/analyzers/{analyzerId}"
15-
PATH_ANALYZER_MANAGEMENT_OPERATION = "/analyzers/{analyzerId}/operations/{operationId}"
16-
17-
# Define Analyzer inference paths
18-
PATH_ANALYZER_INFERENCE = "/analyzers/{analyzerId}:analyze"
19-
PATH_ANALYZER_INFERENCE_GET_IMAGE = "/analyzers/{analyzerId}/results/{operationId}/images/{imageId}"
20-
21-
analyzer_name = "image_analyzer"
22-
image_schema = {
23-
"analyzerId": analyzer_name,
24-
"name": "Image understanding",
25-
"description": "Extract detailed structured information from images extracted from documents.",
26-
"baseAnalyzerId": "prebuilt-image",
27-
"scenario": "image",
28-
"config": {"returnDetails": False},
29-
"fieldSchema": {
30-
"name": "ImageInformation",
31-
"descriptions": "Description of image.",
32-
"fields": {
33-
"Description": {
34-
"type": "string",
35-
"description": "Description of the image. If the image has a title, start with the title. Include a 2-sentence summary. If the image is a chart, diagram, or table, include the underlying data in an HTML table tag, with accurate numbers. If the image is a chart, describe any axis or legends. The only allowed HTML tags are the table/thead/tr/td/tbody tags.",
36-
},
37-
},
38-
},
39-
}
4012

13+
class MediaDescriber(ABC):
4114

42-
class ContentUnderstandingManager:
15+
async def describe_image(self, image_bytes) -> str:
16+
raise NotImplementedError
17+
18+
19+
class ContentUnderstandingDescriber:
20+
CU_API_VERSION = "2024-12-01-preview"
21+
22+
analyzer_schema = {
23+
"analyzerId": "image_analyzer",
24+
"name": "Image understanding",
25+
"description": "Extract detailed structured information from images extracted from documents.",
26+
"baseAnalyzerId": "prebuilt-image",
27+
"scenario": "image",
28+
"config": {"returnDetails": False},
29+
"fieldSchema": {
30+
"name": "ImageInformation",
31+
"descriptions": "Description of image.",
32+
"fields": {
33+
"Description": {
34+
"type": "string",
35+
"description": "Description of the image. If the image has a title, start with the title. Include a 2-sentence summary. If the image is a chart, diagram, or table, include the underlying data in an HTML table tag, with accurate numbers. If the image is a chart, describe any axis or legends. The only allowed HTML tags are the table/thead/tr/td/tbody tags.",
36+
},
37+
},
38+
},
39+
}
4340

44-
def __init__(self, endpoint: str, credential: Union[AsyncTokenCredential, str]):
41+
def __init__(self, endpoint: str, credential: AsyncTokenCredential):
4542
self.endpoint = endpoint
4643
self.credential = credential
4744

@@ -61,16 +58,18 @@ async def poll():
6158
return await poll()
6259

6360
async def create_analyzer(self):
64-
logger.info("Creating analyzer '%s'...", image_schema["analyzerId"])
61+
logger.info("Creating analyzer '%s'...", self.analyzer_schema["analyzerId"])
6562

6663
token_provider = get_bearer_token_provider(self.credential, "https://cognitiveservices.azure.com/.default")
6764
token = await token_provider()
6865
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
69-
params = {"api-version": CU_API_VERSION}
70-
analyzer_id = image_schema["analyzerId"]
66+
params = {"api-version": self.CU_API_VERSION}
67+
analyzer_id = self.analyzer_schema["analyzerId"]
7168
cu_endpoint = f"{self.endpoint}/contentunderstanding/analyzers/{analyzer_id}"
7269
async with aiohttp.ClientSession() as session:
73-
async with session.put(url=cu_endpoint, params=params, headers=headers, json=image_schema) as response:
70+
async with session.put(
71+
url=cu_endpoint, params=params, headers=headers, json=self.analyzer_schema
72+
) as response:
7473
if response.status == 409:
7574
logger.info("Analyzer '%s' already exists.", analyzer_id)
7675
return
@@ -90,8 +89,8 @@ async def describe_image(self, image_bytes) -> str:
9089
async with aiohttp.ClientSession() as session:
9190
token = await self.credential.get_token("https://cognitiveservices.azure.com/.default")
9291
headers = {"Authorization": "Bearer " + token.token}
93-
params = {"api-version": CU_API_VERSION}
94-
92+
params = {"api-version": self.CU_API_VERSION}
93+
analyzer_name = self.analyzer_schema["analyzerId"]
9594
async with session.post(
9695
url=f"{self.endpoint}/contentunderstanding/analyzers/{analyzer_name}:analyze",
9796
params=params,

app/backend/prepdocslib/pdfparser.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from PIL import Image
1919
from pypdf import PdfReader
2020

21-
from .cu_image import ContentUnderstandingManager
21+
from .mediadescriber import ContentUnderstandingDescriber
2222
from .page import Page
2323
from .parser import Parser
2424

@@ -55,7 +55,7 @@ def __init__(
5555
credential: Union[AsyncTokenCredential, AzureKeyCredential],
5656
model_id="prebuilt-layout",
5757
use_content_understanding=True,
58-
content_understanding_endpoint: str = None,
58+
content_understanding_endpoint: Union[str, None] = None,
5959
):
6060
self.model_id = model_id
6161
self.endpoint = endpoint
@@ -66,13 +66,19 @@ def __init__(
6666
async def parse(self, content: IO) -> AsyncGenerator[Page, None]:
6767
logger.info("Extracting text from '%s' using Azure Document Intelligence", content.name)
6868

69-
cu_manager = ContentUnderstandingManager(self.content_understanding_endpoint, self.credential)
7069
async with DocumentIntelligenceClient(
7170
endpoint=self.endpoint, credential=self.credential
7271
) as document_intelligence_client:
7372
# turn content into bytes
7473
content_bytes = content.read()
7574
if self.use_content_understanding:
75+
if self.content_understanding_endpoint is None:
76+
raise ValueError("content_understanding_endpoint should not be None")
77+
if isinstance(self.credential, AzureKeyCredential):
78+
raise ValueError(
79+
"AzureKeyCredential is not supported for Content Understanding, use keyless auth instead"
80+
)
81+
cu_describer = ContentUnderstandingDescriber(self.content_understanding_endpoint, self.credential)
7682
poller = await document_intelligence_client.begin_analyze_document(
7783
model_id="prebuilt-layout",
7884
analyze_request=AnalyzeDocumentRequest(bytes_source=content_bytes),
@@ -111,7 +117,7 @@ class ObjectType(Enum):
111117
# mark all positions of the table spans in the page
112118
page_offset = page.spans[0].offset
113119
page_length = page.spans[0].length
114-
mask_chars = [(ObjectType.NONE, None)] * page_length
120+
mask_chars: list[tuple[ObjectType, Union[int, None]]] = [(ObjectType.NONE, None)] * page_length
115121
for table_idx, table in enumerate(tables_on_page):
116122
for span in table.spans:
117123
# replace all table spans with "table_id" in table_chars array
@@ -132,16 +138,20 @@ class ObjectType(Enum):
132138
added_objects = set() # set of object types todo mypy
133139
for idx, mask_char in enumerate(mask_chars):
134140
object_type, object_idx = mask_char
141+
if object_idx is None:
142+
raise ValueError("object_idx should not be None")
135143
if object_type == ObjectType.NONE:
136144
page_text += form_recognizer_results.content[page_offset + idx]
137145
elif object_type == ObjectType.TABLE:
138146
if mask_char not in added_objects:
139147
page_text += DocumentAnalysisParser.table_to_html(tables_on_page[object_idx])
140148
added_objects.add(mask_char)
141149
elif object_type == ObjectType.FIGURE:
150+
if cu_describer is None:
151+
raise ValueError("cu_describer should not be None, unable to describe figure")
142152
if mask_char not in added_objects:
143153
figure_html = await DocumentAnalysisParser.figure_to_html(
144-
doc_for_pymupdf, cu_manager, figures_on_page[object_idx]
154+
doc_for_pymupdf, cu_describer, figures_on_page[object_idx]
145155
)
146156
page_text += figure_html
147157
added_objects.add(mask_char)
@@ -163,9 +173,12 @@ class ObjectType(Enum):
163173

164174
@staticmethod
165175
async def figure_to_html(
166-
doc: pymupdf.Document, cu_manager: ContentUnderstandingManager, figure: DocumentFigure
176+
doc: pymupdf.Document, cu_describer: ContentUnderstandingDescriber, figure: DocumentFigure
167177
) -> str:
168-
logger.info("Describing figure '%s'", figure.id)
178+
figure_title = (figure.caption and figure.caption.content) or ""
179+
logger.info("Describing figure '%s' with title '%s'", figure.id, figure_title)
180+
if not figure.bounding_regions:
181+
return f"<figure><figcaption>{figure_title}</figcaption></figure>"
169182
for region in figure.bounding_regions:
170183
# To learn more about bounding regions, see https://aka.ms/bounding-region
171184
bounding_box = (
@@ -176,8 +189,7 @@ async def figure_to_html(
176189
)
177190
page_number = figure.bounding_regions[0]["pageNumber"] # 1-indexed
178191
cropped_img = DocumentAnalysisParser.crop_image_from_pdf_page(doc, page_number - 1, bounding_box)
179-
figure_description = await cu_manager.describe_image(cropped_img)
180-
figure_title = (figure.caption and figure.caption.content) or ""
192+
figure_description = await cu_describer.describe_image(cropped_img)
181193
return f"<figure><figcaption>{figure_title}<br>{figure_description}</figcaption></figure>"
182194

183195
@staticmethod
@@ -221,7 +233,7 @@ def crop_image_from_pdf_page(doc: pymupdf.Document, page_number, bounding_box) -
221233
# The matrix is used to convert between these 2 units
222234
pix = page.get_pixmap(matrix=pymupdf.Matrix(300 / 72, 300 / 72), clip=rect)
223235

224-
img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
236+
img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
225237
bytes_io = io.BytesIO()
226238
img.save(bytes_io, format="PNG")
227239
return bytes_io.getvalue()

0 commit comments

Comments
 (0)