|
12 | 12 | from azure.storage.blob.aio import ContainerClient
|
13 | 13 |
|
14 | 14 |
|
| 15 | +POLL_TIMEOUT_SECONDS = 120 |
| 16 | + |
| 17 | + |
15 | 18 | class AzureContentUnderstandingClient:
|
16 | 19 |
|
17 | 20 | PREBUILT_DOCUMENT_ANALYZER_ID: str = "prebuilt-documentAnalyzer"
|
@@ -127,43 +130,39 @@ def _get_headers(
|
127 | 130 | return headers
|
128 | 131 |
|
129 | 132 | @staticmethod
|
130 |
| - def is_supported_type_by_file_path(file_path: Path, is_pro_mode: bool=False) -> bool: |
| 133 | + def is_supported_type_by_file_ext(file_ext: str, is_pro_mode: bool=False) -> bool: |
131 | 134 | """
|
132 |
| - Checks if the given file path has a supported file type. |
| 135 | + Checks if the given file extension is supported. |
133 | 136 |
|
134 | 137 | Args:
|
135 |
| - file_path (Path): The path to the file to check. |
| 138 | + file_ext (str): The file extension to check. |
136 | 139 | is_pro_mode (bool): If True, checks against Pro mode supported file types.
|
137 | 140 |
|
138 | 141 | Returns:
|
139 | 142 | bool: True if the file type is supported, False otherwise.
|
140 | 143 | """
|
141 |
| - if not file_path.is_file(): |
142 |
| - return False |
143 |
| - file_ext = file_path.suffix.lower() |
144 | 144 | supported_types = (
|
145 | 145 | AzureContentUnderstandingClient.SUPPORTED_FILE_TYPES_PRO_MODE
|
146 | 146 | if is_pro_mode else AzureContentUnderstandingClient.SUPPORTED_FILE_TYPES
|
147 | 147 | )
|
148 |
| - return file_ext in supported_types |
149 |
| - |
| 148 | + return file_ext.lower() in supported_types |
| 149 | + |
150 | 150 | @staticmethod
|
151 |
| - def is_supported_type_by_file_ext(file_ext: str, is_pro_mode: bool=False) -> bool: |
| 151 | + def is_supported_type_by_file_path(file_path: Path, is_pro_mode: bool=False) -> bool: |
152 | 152 | """
|
153 |
| - Checks if the given file extension is supported. |
| 153 | + Checks if the given file path has a supported file type. |
154 | 154 |
|
155 | 155 | Args:
|
156 |
| - file_ext (str): The file extension to check. |
| 156 | + file_path (Path): The path to the file to check. |
157 | 157 | is_pro_mode (bool): If True, checks against Pro mode supported file types.
|
158 | 158 |
|
159 | 159 | Returns:
|
160 | 160 | bool: True if the file type is supported, False otherwise.
|
161 | 161 | """
|
162 |
| - supported_types = ( |
163 |
| - AzureContentUnderstandingClient.SUPPORTED_FILE_TYPES_PRO_MODE |
164 |
| - if is_pro_mode else AzureContentUnderstandingClient.SUPPORTED_FILE_TYPES |
165 |
| - ) |
166 |
| - return file_ext.lower() in supported_types |
| 162 | + if not file_path.is_file(): |
| 163 | + return False |
| 164 | + file_ext = file_path.suffix.lower() |
| 165 | + return AzureContentUnderstandingClient.is_supported_type_by_file_ext(file_ext, is_pro_mode) |
167 | 166 |
|
168 | 167 | def get_all_analyzers(self) -> Dict[str, Any]:
|
169 | 168 | """
|
@@ -315,10 +314,10 @@ def begin_analyze(self, analyzer_id: str, file_location: str) -> Response:
|
315 | 314 | data = {
|
316 | 315 | "inputs": [
|
317 | 316 | {
|
318 |
| - "name": f.name, |
| 317 | + "name": "_".join(f.relative_to(file_path).parts), # flatten the relative file path into a single string using underscores |
319 | 318 | "data": base64.b64encode(f.read_bytes()).decode("utf-8")
|
320 | 319 | }
|
321 |
| - for f in file_path.iterdir() |
| 320 | + for f in file_path.rglob("*") |
322 | 321 | if f.is_file() and self.is_supported_type_by_file_path(f, is_pro_mode=True)
|
323 | 322 | ]
|
324 | 323 | }
|
@@ -365,7 +364,7 @@ def get_prebuilt_document_analyze_result(self, file_location: str) -> Dict[str,
|
365 | 364 | file_location=file_location,
|
366 | 365 | )
|
367 | 366 |
|
368 |
| - return self.poll_result(response, timeout_seconds=360) |
| 367 | + return self.poll_result(response, timeout_seconds=POLL_TIMEOUT_SECONDS) |
369 | 368 |
|
370 | 369 | async def _upload_file_to_blob(
|
371 | 370 | self, container_client: ContainerClient, file_path: str, target_blob_path: str
|
@@ -397,40 +396,40 @@ async def generate_knowledge_base_on_blob(
|
397 | 396 | storage_container_path_prefix: str,
|
398 | 397 | skip_analyze: bool = False,
|
399 | 398 | ) -> None:
|
400 |
| - container_client = ContainerClient.from_container_url(storage_container_sas_url) |
| 399 | + if not storage_container_path_prefix.endswith("/"): |
| 400 | + storage_container_path_prefix += "/" |
401 | 401 | resources = []
|
402 |
| - for dirpath, _, filenames in os.walk(reference_docs_folder): |
403 |
| - for filename in filenames: |
404 |
| - filename_no_ext, file_ext = os.path.splitext(filename) |
405 |
| - if self.is_supported_type_by_file_ext(file_ext, is_pro_mode=True): |
406 |
| - file_path = os.path.join(dirpath, filename) |
407 |
| - result_file_name = filename_no_ext + self.OCR_RESULT_FILE_SUFFIX |
408 |
| - result_file_blob_path = storage_container_path_prefix + result_file_name |
409 |
| - # Get and upload result.json |
410 |
| - if not skip_analyze: |
411 |
| - self._logger.info(f"Analyzing result for {filename}") |
412 |
| - try: |
413 |
| - analyze_result = self.get_prebuilt_document_analyze_result(file_path) |
414 |
| - except Exception as e: |
415 |
| - self._logger.error(f"Error of getting analyze result of {filename}: {e}") |
416 |
| - continue |
417 |
| - await self._upload_json_to_blob(container_client, analyze_result, result_file_blob_path) |
418 |
| - else: |
419 |
| - self._logger.info(f"Using existing result.json for {filename}") |
420 |
| - result_file_path = os.path.join(dirpath, result_file_name) |
421 |
| - if not os.path.exists(result_file_path): |
422 |
| - self._logger.warning(f"Result file {result_file_name} does not exist, skipping.") |
423 |
| - continue |
424 |
| - await self._upload_file_to_blob(container_client, result_file_path, result_file_blob_path) |
425 |
| - # Upload the original file |
426 |
| - file_blob_path = storage_container_path_prefix + filename |
427 |
| - await self._upload_file_to_blob(container_client, file_path, file_blob_path) |
428 |
| - resources.append({"file": filename, "resultFile": result_file_name}) |
429 |
| - # Upload sources.jsonl |
430 |
| - await self.upload_jsonl_to_blob( |
431 |
| - container_client, resources, storage_container_path_prefix + self.KNOWLEDGE_SOURCE_LIST_FILE_NAME) |
432 |
| - await container_client.close() |
433 |
| - |
| 402 | + async with ContainerClient.from_container_url(storage_container_sas_url) as container_client: |
| 403 | + for dirpath, _, filenames in os.walk(reference_docs_folder): |
| 404 | + for filename in filenames: |
| 405 | + filename_no_ext, file_ext = os.path.splitext(filename) |
| 406 | + if self.is_supported_type_by_file_ext(file_ext, is_pro_mode=True): |
| 407 | + file_path = os.path.join(dirpath, filename) |
| 408 | + result_file_name = filename_no_ext + self.OCR_RESULT_FILE_SUFFIX |
| 409 | + result_file_blob_path = storage_container_path_prefix + result_file_name |
| 410 | + # Get and upload result.json |
| 411 | + if not skip_analyze: |
| 412 | + self._logger.info(f"Analyzing result for {filename}") |
| 413 | + try: |
| 414 | + analyze_result = self.get_prebuilt_document_analyze_result(file_path) |
| 415 | + except Exception as e: |
| 416 | + self._logger.error(f"Error of getting analyze result of {filename}: {e}") |
| 417 | + continue |
| 418 | + await self._upload_json_to_blob(container_client, analyze_result, result_file_blob_path) |
| 419 | + else: |
| 420 | + self._logger.info(f"Using existing result.json for {filename}") |
| 421 | + result_file_path = os.path.join(dirpath, result_file_name) |
| 422 | + if not os.path.exists(result_file_path): |
| 423 | + self._logger.warning(f"Result file {result_file_name} does not exist, skipping.") |
| 424 | + continue |
| 425 | + await self._upload_file_to_blob(container_client, result_file_path, result_file_blob_path) |
| 426 | + # Upload the original file |
| 427 | + file_blob_path = storage_container_path_prefix + filename |
| 428 | + await self._upload_file_to_blob(container_client, file_path, file_blob_path) |
| 429 | + resources.append({"file": filename, "resultFile": result_file_name}) |
| 430 | + # Upload sources.jsonl |
| 431 | + await self.upload_jsonl_to_blob( |
| 432 | + container_client, resources, storage_container_path_prefix + self.KNOWLEDGE_SOURCE_LIST_FILE_NAME) |
434 | 433 |
|
435 | 434 | def get_image_from_analyze_operation(
|
436 | 435 | self, analyze_response: Response, image_id: str
|
@@ -552,7 +551,7 @@ def begin_classify(self, classifier_id: str, file_location: str) -> Response:
|
552 | 551 | def poll_result(
|
553 | 552 | self,
|
554 | 553 | response: Response,
|
555 |
| - timeout_seconds: int = 120, |
| 554 | + timeout_seconds: int = POLL_TIMEOUT_SECONDS, |
556 | 555 | polling_interval_seconds: int = 2,
|
557 | 556 | ) -> Dict[str, Any]:
|
558 | 557 | """
|
|
0 commit comments