6
6
import time
7
7
8
8
from requests .models import Response
9
- from typing import Any
9
+ from typing import Any , Callable , Optional , Dict , List , Union
10
10
from pathlib import Path
11
11
12
12
from azure .storage .blob .aio import ContainerClient
13
13
14
14
15
15
class AzureContentUnderstandingClient :
16
16
17
- PREBUILT_DOCUMENT_ANALYZER_ID = "prebuilt-documentAnalyzer"
18
- RESULT_SUFFIX = ".result.json"
19
- SOURCES_JSONL = "sources.jsonl"
17
+ PREBUILT_DOCUMENT_ANALYZER_ID : str = "prebuilt-documentAnalyzer"
18
+ RESULT_SUFFIX : str = ".result.json"
19
+ SOURCES_JSONL : str = "sources.jsonl"
20
20
21
21
# https://learn.microsoft.com/en-us/azure/ai-services/content-understanding/service-limits#document-and-text
22
- SUPPORTED_FILE_TYPES = [
22
+ SUPPORTED_FILE_TYPES : List [ str ] = [
23
23
".pdf" ,
24
24
".tiff" ,
25
25
".jpg" ,
@@ -38,7 +38,7 @@ class AzureContentUnderstandingClient:
38
38
".xml" ,
39
39
]
40
40
41
- SUPPORTED_FILE_TYPES_PRO_MODE = [
41
+ SUPPORTED_FILE_TYPES_PRO_MODE : List [ str ] = [
42
42
".pdf" ,
43
43
".tiff" ,
44
44
".jpg" ,
@@ -73,41 +73,43 @@ def __init__(
73
73
74
74
self ._headers = self ._get_headers (subscription_key , token , x_ms_useragent )
75
75
76
- def _get_analyzer_url (self , endpoint , api_version , analyzer_id ) :
76
+ def _get_analyzer_url (self , endpoint : str , api_version : str , analyzer_id : str ) -> str :
77
77
return f"{ endpoint } /contentunderstanding/analyzers/{ analyzer_id } ?api-version={ api_version } " # noqa
78
78
79
- def _get_analyzer_list_url (self , endpoint , api_version ) :
79
+ def _get_analyzer_list_url (self , endpoint : str , api_version : str ) -> str :
80
80
return f"{ endpoint } /contentunderstanding/analyzers?api-version={ api_version } "
81
81
82
- def _get_analyze_url (self , endpoint , api_version , analyzer_id ) :
82
+ def _get_analyze_url (self , endpoint : str , api_version : str , analyzer_id : str ) -> str :
83
83
return f"{ endpoint } /contentunderstanding/analyzers/{ analyzer_id } :analyze?api-version={ api_version } " # noqa
84
84
85
85
def _get_training_data_config (
86
- self , storage_container_sas_url , storage_container_path_prefix
87
- ):
86
+ self , storage_container_sas_url : str , storage_container_path_prefix : str
87
+ ) -> Dict [ str , str ] :
88
88
return {
89
89
"containerUrl" : storage_container_sas_url ,
90
90
"kind" : "blob" ,
91
91
"prefix" : storage_container_path_prefix ,
92
92
}
93
93
94
94
def _get_pro_mode_reference_docs_config (
95
- self , storage_container_sas_url , storage_container_path_prefix
96
- ):
95
+ self , storage_container_sas_url : str , storage_container_path_prefix : str
96
+ ) -> List [ Dict [ str , str ]] :
97
97
return [{
98
98
"kind" : "reference" ,
99
99
"containerUrl" : storage_container_sas_url ,
100
100
"prefix" : storage_container_path_prefix ,
101
101
"fileListPath" : self .SOURCES_JSONL ,
102
102
}]
103
103
104
- def _get_classifier_url (self , endpoint , api_version , classifier_id ) :
104
+ def _get_classifier_url (self , endpoint : str , api_version : str , classifier_id : str ) -> str :
105
105
return f"{ endpoint } /contentunderstanding/classifiers/{ classifier_id } ?api-version={ api_version } "
106
106
107
- def _get_classify_url (self , endpoint , api_version , classifier_id ) :
107
+ def _get_classify_url (self , endpoint : str , api_version : str , classifier_id : str ) -> str :
108
108
return f"{ endpoint } /contentunderstanding/classifiers/{ classifier_id } :classify?api-version={ api_version } "
109
109
110
- def _get_headers (self , subscription_key , api_token , x_ms_useragent ):
110
+ def _get_headers (
111
+ self , subscription_key : str , api_token : str , x_ms_useragent : str
112
+ ) -> Dict [str , str ]:
111
113
"""Returns the headers for the HTTP requests.
112
114
Args:
113
115
subscription_key (str): The subscription key for the service.
@@ -163,7 +165,7 @@ def is_supported_type_by_file_ext(file_ext: str, is_pro_mode: bool=False) -> boo
163
165
)
164
166
return file_ext .lower () in supported_types
165
167
166
- def get_all_analyzers (self ):
168
+ def get_all_analyzers (self ) -> Dict [ str , Any ] :
167
169
"""
168
170
Retrieves a list of all available analyzers from the content understanding service.
169
171
@@ -184,7 +186,7 @@ def get_all_analyzers(self):
184
186
response .raise_for_status ()
185
187
return response .json ()
186
188
187
- def get_analyzer_detail_by_id (self , analyzer_id ) :
189
+ def get_analyzer_detail_by_id (self , analyzer_id : str ) -> Dict [ str , Any ] :
188
190
"""
189
191
Retrieves the details of a specific analyzer, identified by its analyzer ID, from the content understanding service.
190
192
This method sends a GET request to the service endpoint to get the analyzer detail.
@@ -214,7 +216,7 @@ def begin_create_analyzer(
214
216
training_storage_container_path_prefix : str = "" ,
215
217
pro_mode_reference_docs_storage_container_sas_url : str = "" ,
216
218
pro_mode_reference_docs_storage_container_path_prefix : str = "" ,
217
- ):
219
+ ) -> Response :
218
220
"""
219
221
Initiates the creation of an analyzer with the given ID and schema.
220
222
@@ -269,7 +271,7 @@ def begin_create_analyzer(
269
271
self ._logger .info (f"Analyzer { analyzer_id } create request accepted." )
270
272
return response
271
273
272
- def delete_analyzer (self , analyzer_id : str ):
274
+ def delete_analyzer (self , analyzer_id : str ) -> Response :
273
275
"""
274
276
Deletes an analyzer with the specified analyzer ID.
275
277
@@ -290,7 +292,7 @@ def delete_analyzer(self, analyzer_id: str):
290
292
self ._logger .info (f"Analyzer { analyzer_id } deleted." )
291
293
return response
292
294
293
- def begin_analyze (self , analyzer_id : str , file_location : str ):
295
+ def begin_analyze (self , analyzer_id : str , file_location : str ) -> Response :
294
296
"""
295
297
Begins the analysis of a file or URL using the specified analyzer.
296
298
@@ -357,26 +359,32 @@ def begin_analyze(self, analyzer_id: str, file_location: str):
357
359
)
358
360
return response
359
361
360
- def get_analyze_result (self , file_location : str ):
362
+ def get_analyze_result (self , file_location : str ) -> Dict [ str , Any ] :
361
363
response = self .begin_analyze (
362
364
analyzer_id = self .PREBUILT_DOCUMENT_ANALYZER_ID ,
363
365
file_location = file_location ,
364
366
)
365
367
366
368
return self .poll_result (response , timeout_seconds = 360 )
367
369
368
- async def _upload_file_to_blob (self , container_client : ContainerClient , file_path : str , target_blob_path : str ):
370
+ async def _upload_file_to_blob (
371
+ self , container_client : ContainerClient , file_path : str , target_blob_path : str
372
+ ) -> None :
369
373
with open (file_path , "rb" ) as data :
370
374
await container_client .upload_blob (name = target_blob_path , data = data , overwrite = True )
371
375
self ._logger .info (f"Uploaded file to { target_blob_path } " )
372
376
373
- async def _upload_json_to_blob (self , container_client : ContainerClient , data : dict , target_blob_path : str ):
377
+ async def _upload_json_to_blob (
378
+ self , container_client : ContainerClient , data : Dict [str , Any ], target_blob_path : str
379
+ ) -> None :
374
380
json_str = json .dumps (data , indent = 4 )
375
381
json_bytes = json_str .encode ('utf-8' )
376
382
await container_client .upload_blob (name = target_blob_path , data = json_bytes , overwrite = True )
377
383
self ._logger .info (f"Uploaded json to { target_blob_path } " )
378
384
379
- async def upload_jsonl_to_blob (self , container_client : ContainerClient , data_list : list [dict ], target_blob_path : str ):
385
+ async def upload_jsonl_to_blob (
386
+ self , container_client : ContainerClient , data_list : List [Dict [str , Any ]], target_blob_path : str
387
+ ) -> None :
380
388
jsonl_string = "\n " .join (json .dumps (record ) for record in data_list )
381
389
jsonl_bytes = jsonl_string .encode ("utf-8" )
382
390
await container_client .upload_blob (name = target_blob_path , data = jsonl_bytes , overwrite = True )
@@ -388,7 +396,7 @@ async def generate_knowledge_base_on_blob(
388
396
storage_container_sas_url : str ,
389
397
storage_container_path_prefix : str ,
390
398
skip_analyze : bool = False ,
391
- ):
399
+ ) -> None :
392
400
container_client = ContainerClient .from_container_url (storage_container_sas_url )
393
401
resources = []
394
402
for dirpath , _ , filenames in os .walk (referemce_docs_folder ):
@@ -425,7 +433,7 @@ async def generate_knowledge_base_on_blob(
425
433
426
434
def get_image_from_analyze_operation (
427
435
self , analyze_response : Response , image_id : str
428
- ):
436
+ ) -> Optional [ bytes ] :
429
437
"""Retrieves an image from the analyze operation using the image ID.
430
438
Args:
431
439
analyze_response (Response): The response object from the analyze operation.
@@ -456,8 +464,8 @@ def get_image_from_analyze_operation(
456
464
def begin_create_classifier (
457
465
self ,
458
466
classifier_id : str ,
459
- classifier_schema : dict ,
460
- ):
467
+ classifier_schema : Dict [ str , Any ] ,
468
+ ) -> Response :
461
469
"""
462
470
Initiates the creation of a classifier with the given ID and schema.
463
471
@@ -490,7 +498,7 @@ def begin_create_classifier(
490
498
self ._logger .info (f"Classifier { classifier_id } create request accepted." )
491
499
return response
492
500
493
- def begin_classify (self , classifier_id : str , file_location : str ):
501
+ def begin_classify (self , classifier_id : str , file_location : str ) -> Response :
494
502
"""
495
503
Begins the analysis of a file or URL using the specified classifier.
496
504
@@ -545,7 +553,7 @@ def poll_result(
545
553
response : Response ,
546
554
timeout_seconds : int = 120 ,
547
555
polling_interval_seconds : int = 2 ,
548
- ):
556
+ ) -> Dict [ str , Any ] :
549
557
"""
550
558
Polls the result of an asynchronous operation until it completes or times out.
551
559
0 commit comments