Skip to content

Commit 7c314fd

Browse files
add typing for client
1 parent 25df7b7 commit 7c314fd

File tree

1 file changed

+39
-31
lines changed

1 file changed

+39
-31
lines changed

python/content_understanding_client.py

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,20 @@
66
import time
77

88
from requests.models import Response
9-
from typing import Any
9+
from typing import Any, Callable, Optional, Dict, List, Union
1010
from pathlib import Path
1111

1212
from azure.storage.blob.aio import ContainerClient
1313

1414

1515
class AzureContentUnderstandingClient:
1616

17-
PREBUILT_DOCUMENT_ANALYZER_ID = "prebuilt-documentAnalyzer"
18-
RESULT_SUFFIX = ".result.json"
19-
SOURCES_JSONL = "sources.jsonl"
17+
PREBUILT_DOCUMENT_ANALYZER_ID: str = "prebuilt-documentAnalyzer"
18+
RESULT_SUFFIX: str = ".result.json"
19+
SOURCES_JSONL: str = "sources.jsonl"
2020

2121
# https://learn.microsoft.com/en-us/azure/ai-services/content-understanding/service-limits#document-and-text
22-
SUPPORTED_FILE_TYPES = [
22+
SUPPORTED_FILE_TYPES: List[str] = [
2323
".pdf",
2424
".tiff",
2525
".jpg",
@@ -38,7 +38,7 @@ class AzureContentUnderstandingClient:
3838
".xml",
3939
]
4040

41-
SUPPORTED_FILE_TYPES_PRO_MODE = [
41+
SUPPORTED_FILE_TYPES_PRO_MODE: List[str] = [
4242
".pdf",
4343
".tiff",
4444
".jpg",
@@ -73,41 +73,43 @@ def __init__(
7373

7474
self._headers = self._get_headers(subscription_key, token, x_ms_useragent)
7575

76-
def _get_analyzer_url(self, endpoint, api_version, analyzer_id):
76+
def _get_analyzer_url(self, endpoint: str, api_version: str, analyzer_id: str) -> str:
7777
return f"{endpoint}/contentunderstanding/analyzers/{analyzer_id}?api-version={api_version}" # noqa
7878

79-
def _get_analyzer_list_url(self, endpoint, api_version):
79+
def _get_analyzer_list_url(self, endpoint: str, api_version: str) -> str:
8080
return f"{endpoint}/contentunderstanding/analyzers?api-version={api_version}"
8181

82-
def _get_analyze_url(self, endpoint, api_version, analyzer_id):
82+
def _get_analyze_url(self, endpoint: str, api_version: str, analyzer_id: str) -> str:
8383
return f"{endpoint}/contentunderstanding/analyzers/{analyzer_id}:analyze?api-version={api_version}" # noqa
8484

8585
def _get_training_data_config(
86-
self, storage_container_sas_url, storage_container_path_prefix
87-
):
86+
self, storage_container_sas_url: str, storage_container_path_prefix: str
87+
) -> Dict[str, str]:
8888
return {
8989
"containerUrl": storage_container_sas_url,
9090
"kind": "blob",
9191
"prefix": storage_container_path_prefix,
9292
}
9393

9494
def _get_pro_mode_reference_docs_config(
95-
self, storage_container_sas_url, storage_container_path_prefix
96-
):
95+
self, storage_container_sas_url: str, storage_container_path_prefix: str
96+
) -> List[Dict[str, str]]:
9797
return [{
9898
"kind": "reference",
9999
"containerUrl": storage_container_sas_url,
100100
"prefix": storage_container_path_prefix,
101101
"fileListPath": self.SOURCES_JSONL,
102102
}]
103103

104-
def _get_classifier_url(self, endpoint, api_version, classifier_id):
104+
def _get_classifier_url(self, endpoint: str, api_version: str, classifier_id: str) -> str:
105105
return f"{endpoint}/contentunderstanding/classifiers/{classifier_id}?api-version={api_version}"
106106

107-
def _get_classify_url(self, endpoint, api_version, classifier_id):
107+
def _get_classify_url(self, endpoint: str, api_version: str, classifier_id: str) -> str:
108108
return f"{endpoint}/contentunderstanding/classifiers/{classifier_id}:classify?api-version={api_version}"
109109

110-
def _get_headers(self, subscription_key, api_token, x_ms_useragent):
110+
def _get_headers(
111+
self, subscription_key: str, api_token: str, x_ms_useragent: str
112+
) -> Dict[str, str]:
111113
"""Returns the headers for the HTTP requests.
112114
Args:
113115
subscription_key (str): The subscription key for the service.
@@ -163,7 +165,7 @@ def is_supported_type_by_file_ext(file_ext: str, is_pro_mode: bool=False) -> boo
163165
)
164166
return file_ext.lower() in supported_types
165167

166-
def get_all_analyzers(self):
168+
def get_all_analyzers(self) -> Dict[str, Any]:
167169
"""
168170
Retrieves a list of all available analyzers from the content understanding service.
169171
@@ -184,7 +186,7 @@ def get_all_analyzers(self):
184186
response.raise_for_status()
185187
return response.json()
186188

187-
def get_analyzer_detail_by_id(self, analyzer_id):
189+
def get_analyzer_detail_by_id(self, analyzer_id: str) -> Dict[str, Any]:
188190
"""
189191
Retrieves the details of a specific analyzer by its analyzer ID from the content understanding service.
190192
This method sends a GET request to the service endpoint to get the analyzer detail.
@@ -214,7 +216,7 @@ def begin_create_analyzer(
214216
training_storage_container_path_prefix: str = "",
215217
pro_mode_reference_docs_storage_container_sas_url: str = "",
216218
pro_mode_reference_docs_storage_container_path_prefix: str = "",
217-
):
219+
) -> Response:
218220
"""
219221
Initiates the creation of an analyzer with the given ID and schema.
220222
@@ -269,7 +271,7 @@ def begin_create_analyzer(
269271
self._logger.info(f"Analyzer {analyzer_id} create request accepted.")
270272
return response
271273

272-
def delete_analyzer(self, analyzer_id: str):
274+
def delete_analyzer(self, analyzer_id: str) -> Response:
273275
"""
274276
Deletes an analyzer with the specified analyzer ID.
275277
@@ -290,7 +292,7 @@ def delete_analyzer(self, analyzer_id: str):
290292
self._logger.info(f"Analyzer {analyzer_id} deleted.")
291293
return response
292294

293-
def begin_analyze(self, analyzer_id: str, file_location: str):
295+
def begin_analyze(self, analyzer_id: str, file_location: str) -> Response:
294296
"""
295297
Begins the analysis of a file or URL using the specified analyzer.
296298
@@ -357,26 +359,32 @@ def begin_analyze(self, analyzer_id: str, file_location: str):
357359
)
358360
return response
359361

360-
def get_analyze_result(self, file_location: str):
362+
def get_analyze_result(self, file_location: str) -> Dict[str, Any]:
361363
response = self.begin_analyze(
362364
analyzer_id=self.PREBUILT_DOCUMENT_ANALYZER_ID,
363365
file_location=file_location,
364366
)
365367

366368
return self.poll_result(response, timeout_seconds=360)
367369

368-
async def _upload_file_to_blob(self, container_client: ContainerClient, file_path: str, target_blob_path: str):
370+
async def _upload_file_to_blob(
371+
self, container_client: ContainerClient, file_path: str, target_blob_path: str
372+
) -> None:
369373
with open(file_path, "rb") as data:
370374
await container_client.upload_blob(name=target_blob_path, data=data, overwrite=True)
371375
self._logger.info(f"Uploaded file to {target_blob_path}")
372376

373-
async def _upload_json_to_blob(self, container_client: ContainerClient, data: dict, target_blob_path: str):
377+
async def _upload_json_to_blob(
378+
self, container_client: ContainerClient, data: Dict[str, Any], target_blob_path: str
379+
) -> None:
374380
json_str = json.dumps(data, indent=4)
375381
json_bytes = json_str.encode('utf-8')
376382
await container_client.upload_blob(name=target_blob_path, data=json_bytes, overwrite=True)
377383
self._logger.info(f"Uploaded json to {target_blob_path}")
378384

379-
async def upload_jsonl_to_blob(self, container_client: ContainerClient, data_list: list[dict], target_blob_path: str):
385+
async def upload_jsonl_to_blob(
386+
self, container_client: ContainerClient, data_list: List[Dict[str, Any]], target_blob_path: str
387+
) -> None:
380388
jsonl_string = "\n".join(json.dumps(record) for record in data_list)
381389
jsonl_bytes = jsonl_string.encode("utf-8")
382390
await container_client.upload_blob(name=target_blob_path, data=jsonl_bytes, overwrite=True)
@@ -388,7 +396,7 @@ async def generate_knowledge_base_on_blob(
388396
storage_container_sas_url: str,
389397
storage_container_path_prefix: str,
390398
skip_analyze: bool = False,
391-
):
399+
) -> None:
392400
container_client = ContainerClient.from_container_url(storage_container_sas_url)
393401
resources = []
394402
for dirpath, _, filenames in os.walk(referemce_docs_folder):
@@ -425,7 +433,7 @@ async def generate_knowledge_base_on_blob(
425433

426434
def get_image_from_analyze_operation(
427435
self, analyze_response: Response, image_id: str
428-
):
436+
) -> Optional[bytes]:
429437
"""Retrieves an image from the analyze operation using the image ID.
430438
Args:
431439
analyze_response (Response): The response object from the analyze operation.
@@ -456,8 +464,8 @@ def get_image_from_analyze_operation(
456464
def begin_create_classifier(
457465
self,
458466
classifier_id: str,
459-
classifier_schema: dict,
460-
):
467+
classifier_schema: Dict[str, Any],
468+
) -> Response:
461469
"""
462470
Initiates the creation of a classifier with the given ID and schema.
463471
@@ -490,7 +498,7 @@ def begin_create_classifier(
490498
self._logger.info(f"Classifier {classifier_id} create request accepted.")
491499
return response
492500

493-
def begin_classify(self, classifier_id: str, file_location: str):
501+
def begin_classify(self, classifier_id: str, file_location: str) -> Response:
494502
"""
495503
Begins the analysis of a file or URL using the specified classifier.
496504
@@ -545,7 +553,7 @@ def poll_result(
545553
response: Response,
546554
timeout_seconds: int = 120,
547555
polling_interval_seconds: int = 2,
548-
):
556+
) -> Dict[str, Any]:
549557
"""
550558
Polls the result of an asynchronous operation until it completes or times out.
551559

0 commit comments

Comments
 (0)