Skip to content

Commit aa3e7e2

Browse files
committed
Implemented draft versions of both the structure_file() and check_execution_status() functions.
1 parent 554f938 commit aa3e7e2

File tree

2 files changed

+206
-0
lines changed

2 files changed

+206
-0
lines changed
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
"""
2+
This module provides an API client to invoke APIs deployed on the Unstract platform.
3+
4+
Classes:
5+
APIDeploymentsClient: A class to invoke APIs deployed on the Unstract platform.
6+
APIDeploymentsClientException: A class to handle exceptions raised by the APIDeploymentsClient class.
7+
"""
8+
import logging
9+
import os
10+
11+
import requests
12+
import ntpath
13+
from urllib.parse import urlparse, parse_qs
14+
15+
from src.unstract.api_deployments.utils import UnstractUtils
16+
17+
18+
class APIDeploymentsClientException(Exception):
    """
    Exception raised by the APIDeploymentsClient class.

    The payload passed to the constructor is kept on ``value`` and is also
    forwarded to ``Exception`` so that ``args`` is populated.
    """

    def __init__(self, value):
        """
        Creates the exception.

        Args:
            value: The error payload; in this module it is a message string.
        """
        # A single constructor: the draft had two consecutive ``__init__``
        # headers, the first of which had no body of its own.
        super().__init__(value)
        self.value = value

    def __str__(self):
        """Returns the ``repr()`` of the stored value."""
        return repr(self.value)

    def error_message(self):
        """Returns the stored value unchanged (no ``repr()`` quoting)."""
        return self.value
32+
33+
34+
class APIDeploymentsClient:
    """
    A class to invoke APIs deployed on the Unstract platform.

    ``structure_file()`` uploads files to the deployed API. When the returned
    dict has ``pending`` set to True, its ``status_check_api_endpoint`` value
    can be passed to ``check_execution_status()`` to poll for the outcome.
    """

    # Logging is wired up once, when the class body is executed; every
    # instance shares this logger and its stream handler.
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    logger = logging.getLogger(__name__)
    log_stream_handler = logging.StreamHandler()
    log_stream_handler.setFormatter(formatter)
    logger.addHandler(log_stream_handler)

    # Class-level defaults; __init__ overrides both on the instance.
    api_key = ""
    api_timeout = 300

    def __init__(
        self,
        api_url: str,
        api_key: str,
        api_timeout: int = 300,
        logging_level: str = "INFO",
    ):
        """
        Initializes the APIDeploymentsClient class.

        Args:
            api_url (str): The full URL of the deployed API. Its scheme and
                host are also stored as the base URL for status checks.
            api_key (str): The API key to authenticate the API request. When
                empty, the UNSTRACT_API_DEPLOYMENT_KEY environment variable
                is used instead.
            api_timeout (int): The timeout value sent to the API with each
                file submission.
            logging_level (str): One of "DEBUG", "INFO", "WARNING" or "ERROR".
                When empty, the UNSTRACT_API_CLIENT_LOGGING_LEVEL environment
                variable is used (default "INFO"). Any other value leaves the
                logger level untouched.
        """
        if logging_level == "":
            logging_level = os.getenv("UNSTRACT_API_CLIENT_LOGGING_LEVEL", "INFO")
        level = {
            "DEBUG": logging.DEBUG,
            "INFO": logging.INFO,
            "WARNING": logging.WARNING,
            "ERROR": logging.ERROR,
        }.get(logging_level)
        if level is not None:
            self.logger.setLevel(level)
        self.logger.debug("Logging level set to: %s", logging_level)

        if not api_key:
            # Fall back to the environment. Default to "" rather than None so
            # that redacting the key and building the Authorization header do
            # not fail with a type error when the variable is unset.
            api_key = os.getenv("UNSTRACT_API_DEPLOYMENT_KEY", "")
            if not api_key:
                self.logger.warning(
                    "No API key provided and UNSTRACT_API_DEPLOYMENT_KEY is not set"
                )
        self.api_key = api_key
        self.logger.debug("API key set to: %s", UnstractUtils.redact_key(self.api_key))

        self.api_timeout = api_timeout
        self.api_url = api_url
        self.__save_base_url(api_url)

    def __save_base_url(self, full_url: str):
        """
        Extracts the base URL from the full URL and saves it.

        Args:
            full_url (str): The full URL of the API.
        """
        parsed_url = urlparse(full_url)
        # Keep only "<scheme>://<host[:port]>"; path and query are dropped.
        self.base_url = parsed_url.scheme + "://" + parsed_url.netloc
        self.logger.debug("Base URL: %s", self.base_url)

    @staticmethod
    def _parse_json_body(response) -> dict:
        """Returns the response body as a dict, or {} if it is not a JSON object."""
        try:
            payload = response.json()
        except ValueError:
            # Raised for bodies that are not valid JSON.
            return {}
        return payload if isinstance(payload, dict) else {}

    def structure_file(self, file_paths: list[str]) -> dict:
        """
        Invokes the API deployed on the Unstract platform.

        Args:
            file_paths (list[str]): Paths of the files to be uploaded.

        Returns:
            dict: A simplified view of the response with the keys
                "pending", "execution_status", "error", "extraction_result"
                and "status_code". Fields absent from the response are None.
                When the call succeeded and the execution is still pending,
                "pending" is True and "status_check_api_endpoint" is added.

        Raises:
            APIDeploymentsClientException: If one of the files does not exist.
        """
        self.logger.debug("Invoking API: %s", self.api_url)
        self.logger.debug("File paths: %s", file_paths)

        headers = {
            "Authorization": "Bearer " + self.api_key,
        }

        data = {
            "timeout": self.api_timeout,
        }

        opened_files = []
        try:
            files = []
            for file_path in file_paths:
                try:
                    handle = open(file_path, "rb")
                except FileNotFoundError as e:
                    raise APIDeploymentsClientException("File not found: " + str(e)) from e
                opened_files.append(handle)
                # ntpath.basename splits on both "/" and "\\" separators.
                files.append(
                    ("files", (ntpath.basename(file_path), handle, "application/octet-stream"))
                )

            response = requests.post(
                self.api_url,
                headers=headers,
                data=data,
                files=files,
            )
        finally:
            # Close every handle that was opened, whether the upload
            # succeeded, failed, or a later path turned out to be missing.
            for handle in opened_files:
                handle.close()

        self.logger.debug(response.status_code)
        self.logger.debug(response.text)

        # The returned object is wrapped in a "message" key. Let's simplify the response.
        message = self._parse_json_body(response).get("message")
        fields = message if isinstance(message, dict) else {}
        obj_to_return = {
            "pending": False,
            "execution_status": fields.get("execution_status"),
            "error": fields.get("error"),
            "extraction_result": fields.get("result"),
            "status_code": response.status_code,
        }
        if message is not None and not isinstance(message, dict):
            # No structured fields to unpack; surface the raw message so the
            # caller can still see what the API answered.
            obj_to_return["error"] = message

        # If the execution is pending, expose the status endpoint from the response.
        # Later, users can pass it to check_execution_status() to poll the execution.
        if 200 <= response.status_code < 300 and obj_to_return["execution_status"] == "PENDING":
            obj_to_return["status_check_api_endpoint"] = fields.get("status_api")
            obj_to_return["pending"] = True

        return obj_to_return

    def check_execution_status(self, status_check_api_endpoint: str) -> dict:
        """
        Checks the status of the execution.

        Args:
            status_check_api_endpoint (str): The API endpoint to check the
                status of the execution. It is appended to the base URL
                derived from the API URL given to the constructor.

        Returns:
            dict: The keys "pending", "execution_status", "status_code" and
                "message". Fields absent from the response are None. When the
                call succeeded with status "SUCCESS", "extraction_result" is
                added and holds the response's "message" value.
        """
        headers = {
            "Authorization": "Bearer " + self.api_key,
        }
        status_call_url = self.base_url + status_check_api_endpoint
        self.logger.debug("Checking execution status via endpoint: %s", status_call_url)
        response = requests.get(
            status_call_url,
            headers=headers,
        )
        self.logger.debug(response.status_code)
        self.logger.debug(response.text)

        payload = self._parse_json_body(response)
        obj_to_return = {
            "pending": False,
            "execution_status": payload.get("status"),
            "status_code": response.status_code,
            "message": payload.get("message"),
        }

        # A PENDING status is honoured for any 2xx-4xx answer, not only 2xx.
        # NOTE(review): presumably the status API can reply with a 4xx code
        # while the execution is still running - confirm against the API.
        if 200 <= response.status_code < 500 and obj_to_return["execution_status"] == "PENDING":
            obj_to_return["pending"] = True

        if 200 <= response.status_code < 300 and obj_to_return["execution_status"] == "SUCCESS":
            obj_to_return["extraction_result"] = payload.get("message")

        return obj_to_return
184+
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
class UnstractUtils:
    """Utility helpers for the Unstract API client code."""

    @staticmethod
    def redact_key(api_key: str, reveal_length=4) -> str:
        """Masks an API key for logging, keeping only its leading characters.

        Args:
            api_key (str): The key to mask.
            reveal_length (int): How many leading characters stay readable.

        Returns:
            str: A string as long as ``api_key`` in which the first
                ``reveal_length`` characters are kept and every remaining
                character is replaced by "x". A key no longer than
                ``reveal_length`` is returned unchanged.

        Raises:
            ValueError: If ``api_key`` is not a string or ``reveal_length``
                is negative.
        """
        if not isinstance(api_key, str):
            raise ValueError("API key must be a string")
        if reveal_length < 0:
            raise ValueError("Reveal length must be a non-negative integer")

        # Keep the visible prefix, then pad with "x" back to the key's length.
        visible_prefix = api_key[:reveal_length]
        return visible_prefix.ljust(len(api_key), "x")

0 commit comments

Comments
 (0)