
Commit ef36e19

revise for pro mode
1 parent fc62631 commit ef36e19

File tree:
  python/content_understanding_client.py
  requirements.txt

2 files changed: +150, -10 lines

python/content_understanding_client.py

Lines changed: 148 additions & 10 deletions
@@ -1,12 +1,53 @@
-import requests
-from requests.models import Response
-import logging
+import base64
 import json
+import logging
+import os
+import requests
 import time
+
+from requests.models import Response
+from typing import Any
 from pathlib import Path
 
+from azure.storage.blob.aio import ContainerClient
+
 
 class AzureContentUnderstandingClient:
+
+    PREBUILT_DOCUMENT_ANALYZER_ID = "prebuilt-documentAnalyzer"
+    RESULT_SUFFIX = ".result.json"
+    SOURCES_JSONL = "sources.jsonl"
+
+    # https://learn.microsoft.com/en-us/azure/ai-services/content-understanding/service-limits#document-and-text
+    SUPPORTED_FILE_TYPES = [
+        ".pdf",
+        ".tiff",
+        ".jpg",
+        ".jpeg",
+        ".png",
+        ".bmp",
+        ".heif",
+        ".docx",
+        ".xlsx",
+        ".pptx",
+        ".txt",
+        ".html",
+        ".md",
+        ".eml",
+        ".msg",
+        ".xml",
+    ]
+
+    SUPPORTED_FILE_TYPES_PRO_MODE = [
+        ".pdf",
+        ".tiff",
+        ".jpg",
+        ".jpeg",
+        ".png",
+        ".bmp",
+        ".heif",
+    ]
+
     def __init__(
         self,
         endpoint: str,
@@ -53,12 +94,12 @@ def _get_training_data_config(
     def _get_pro_mode_reference_docs_config(
         self, storage_container_sas_url, storage_container_path_prefix
     ):
-        return {
+        return [{
             "kind": "reference",
             "containerUrl": storage_container_sas_url,
             "prefix": storage_container_path_prefix,
-            "fileListPath": "sources.jsonl",
-        }
+            "fileListPath": self.SOURCES_JSONL,
+        }]
 
     def _get_headers(self, subscription_key, api_token, x_ms_useragent):
         """Returns the headers for the HTTP requests.
@@ -76,6 +117,28 @@ def _get_headers(self, subscription_key, api_token, x_ms_useragent):
         )
         headers["x-ms-useragent"] = x_ms_useragent
         return headers
+
+    @staticmethod
+    def is_supported_file_type(file_path: Path, is_pro_mode: bool = False) -> bool:
+        """
+        Checks if the given file path has a supported file type.
+
+        Args:
+            file_path (Path): The path to the file to check.
+            is_pro_mode (bool): If True, checks against pro mode supported file types.
+
+        Returns:
+            bool: True if the file type is supported, False otherwise.
+        """
+        if not file_path.is_file():
+            return False
+        file_ext = file_path.suffix.lower()
+        supported_types = (
+            AzureContentUnderstandingClient.SUPPORTED_FILE_TYPES_PRO_MODE
+            if is_pro_mode else AzureContentUnderstandingClient.SUPPORTED_FILE_TYPES
+        )
+        return file_ext in supported_types
+
 
     def get_all_analyzers(self):
         """
@@ -220,10 +283,27 @@ def begin_analyze(self, analyzer_id: str, file_location: str):
             HTTPError: If the HTTP request returned an unsuccessful status code.
         """
         data = None
-        if Path(file_location).exists():
-            with open(file_location, "rb") as file:
-                data = file.read()
-            headers = {"Content-Type": "application/octet-stream"}
+        file_path = Path(file_location)
+        if file_path.exists():
+            if file_path.is_dir():
+                # Only pro mode supports multiple input files
+                data = {
+                    "inputs": [
+                        {
+                            "name": f.name,
+                            "data": base64.b64encode(f.read_bytes()).decode("utf-8")
+                        }
+                        for f in file_path.iterdir()
+                        if f.is_file() and self.is_supported_file_type(f, is_pro_mode=True)
+                    ]
+                }
+                headers = {"Content-Type": "application/json"}
+            elif file_path.is_file() and self.is_supported_file_type(file_path):
+                with open(file_location, "rb") as file:
+                    data = file.read()
+                headers = {"Content-Type": "application/octet-stream"}
+            else:
+                raise ValueError("File location must be a valid and supported file or directory path.")
         elif "https://" in file_location or "http://" in file_location:
             data = {"url": file_location}
             headers = {"Content-Type": "application/json"}
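
When file_location is a directory, begin_analyze now posts a JSON body with every supported file base64-encoded inline instead of a single octet-stream upload. A sketch of the resulting payload for a folder holding two scans (file names and contents are placeholders):

# Shape of the pro-mode request body built above; each "data" value carries
# the base64-encoded bytes of one file.
payload = {
    "inputs": [
        {"name": "invoice.pdf", "data": "<base64 of invoice.pdf>"},
        {"name": "receipt.png", "data": "<base64 of receipt.png>"},
    ]
}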
@@ -253,6 +333,64 @@ def begin_analyze(self, analyzer_id: str, file_location: str):
             f"Analyzing file {file_location} with analyzer: {analyzer_id}"
         )
         return response
+
+    def get_analyze_result(self, file_location: str):
+        response = self.begin_analyze(
+            analyzer_id=self.PREBUILT_DOCUMENT_ANALYZER_ID,
+            file_location=file_location,
+        )
+
+        return self.poll_result(response, timeout_seconds=360)
+
+    async def _upload_file_to_blob(self, container_client: ContainerClient, file_path: str, target_blob_path: str):
+        with open(file_path, "rb") as data:
+            await container_client.upload_blob(name=target_blob_path, data=data, overwrite=True)
+        self._logger.info(f"Uploaded file to {target_blob_path}")
+
+    async def _upload_json_to_blob(self, container_client: ContainerClient, data: dict, target_blob_path: str):
+        json_str = json.dumps(data, indent=4)
+        json_bytes = json_str.encode("utf-8")
+        await container_client.upload_blob(name=target_blob_path, data=json_bytes, overwrite=True)
+        self._logger.info(f"Uploaded json to {target_blob_path}")
+
+    async def upload_jsonl_to_blob(self, container_client: ContainerClient, data_list: list[dict], target_blob_path: str):
+        jsonl_string = "\n".join(json.dumps(record) for record in data_list)
+        jsonl_bytes = jsonl_string.encode("utf-8")
+        await container_client.upload_blob(name=target_blob_path, data=jsonl_bytes, overwrite=True)
+        self._logger.info(f"Uploaded jsonl to blob '{target_blob_path}'")
+
+    async def generate_knowledge_base_on_blob(
+        self,
+        reference_docs_folder: str,
+        storage_container_sas_url: str,
+        storage_container_path_prefix: str,
+        has_result_json: bool = False,
+    ):
+        container_client = ContainerClient.from_container_url(storage_container_sas_url)
+        if not has_result_json:
+            self._logger.info("Generating knowledge base files...")
+            resources = []
+            for dirpath, _, filenames in os.walk(reference_docs_folder):
+                for filename in filenames:
+                    filename_no_ext, file_ext = os.path.splitext(filename)
+                    if file_ext.lower() in self.SUPPORTED_FILE_TYPES_PRO_MODE:
+                        file_path = os.path.join(dirpath, filename)
+                        file_blob_path = storage_container_path_prefix + filename
+                        self._logger.info(f"Generating result for {file_path}")
+                        try:
+                            analyze_result = self.get_analyze_result(file_path)
+                            result_file_name = filename_no_ext + self.RESULT_SUFFIX
+                            result_file_blob_path = storage_container_path_prefix + result_file_name
+                            await self._upload_json_to_blob(container_client, analyze_result, result_file_blob_path)
+                            await self._upload_file_to_blob(container_client, file_path, file_blob_path)
+                            resources.append({"file": filename, "resultFile": result_file_name})
+                        except Exception as e:
+                            self._logger.error(f"Error generating knowledge base entry for {filename}: {e}")
+                            continue
+            await self.upload_jsonl_to_blob(container_client, resources, storage_container_path_prefix + self.SOURCES_JSONL)
+        await container_client.close()
+        # TODO: add handling for pre-existing .result.json files
+
 
     def get_image_from_analyze_operation(
         self, analyze_response: Response, image_id: str
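
A hedged usage sketch for the new knowledge-base builder. It assumes a client already constructed per __init__ (endpoint plus credentials); the folder, SAS URL, and prefix below are placeholders. Since the method concatenates the prefix directly with file names, the prefix should end with "/".

import asyncio

async def build_knowledge_base(client: AzureContentUnderstandingClient) -> None:
    # Analyzes every pro-mode-supported file under reference_docs/ with the
    # prebuilt document analyzer, uploads each file plus its .result.json,
    # then writes the sources.jsonl manifest to the container.
    await client.generate_knowledge_base_on_blob(
        reference_docs_folder="reference_docs/",
        storage_container_sas_url="https://<account>.blob.core.windows.net/<container>?<sas>",
        storage_container_path_prefix="knowledge_base/",
    )

# asyncio.run(build_knowledge_base(client))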

requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,6 @@
+aiohttp
 azure-identity
+azure-storage-blob
 python-dotenv
 requests
 Pillow
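
Both additions back the new async upload path: azure-storage-blob provides the ContainerClient used for the knowledge-base uploads, and its .aio variant uses aiohttp as the underlying async HTTP transport.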
