Skip to content

Commit f09fc53

Browse files
FEATURE: Added support to call single-pass-extraction (#16)
Added support to call single-pass-extraction Co-authored-by: Arun Venkataswamy <[email protected]>
1 parent 4a5b9f7 commit f09fc53

File tree

1 file changed

+26
-22
lines changed

1 file changed

+26
-22
lines changed

src/unstract/sdk/prompt.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import Any, Optional
22

33
import requests
4+
from requests import RequestException, Response
5+
46
from unstract.sdk.constants import LogLevel, PromptStudioKeys, ToolEnv
57
from unstract.sdk.helper import SdkHelper
68
from unstract.sdk.tool.base import BaseTool
@@ -28,7 +30,15 @@ def __init__(
2830
)
2931
self.bearer_token = tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
3032

31-
def answer_prompt(self, payload: dict) -> dict:
33+
def answer_prompt(self, payload: dict[str, Any]) -> dict[str, Any]:
34+
return self._post_call("answer-prompt", payload)
35+
36+
def single_pass_extraction(self, payload: dict[str, Any]) -> dict[str, Any]:
37+
return self._post_call("single-pass-extraction", payload)
38+
39+
def _post_call(
40+
self, url_path: str, payload: dict[str, Any]
41+
) -> dict[str, Any]:
3242
"""Invokes and communicates to prompt service to fetch response for the
3343
prompt.
3444
@@ -46,36 +56,30 @@ def answer_prompt(self, payload: dict) -> dict:
4656
structure_output : {}
4757
}
4858
"""
49-
result = {
59+
result: dict[str, Any] = {
5060
"status": "ERROR",
5161
"error": "",
5262
"cost": 0,
5363
"structure_output": "",
5464
}
55-
# TODO : Implement authorization for prompt service
56-
# headers = {"Authorization": f"Bearer {self.bearer_token}"}
57-
# Upload file to platform
58-
url = f"{self.base_url}/answer-prompt"
59-
headers = {"Authorization": f"Bearer {self.bearer_token}"}
65+
url: str = f"{self.base_url}/{url_path}"
66+
headers: dict[str, str] = {
67+
"Authorization": f"Bearer {self.bearer_token}"
68+
}
6069
try:
61-
response = requests.post(url, json=payload, headers=headers)
62-
if response.status_code != 200:
63-
self.tool.stream_log(
64-
f"Error while fetching response: {response.text}",
65-
level=LogLevel.ERROR,
66-
)
67-
result["error"] = response.text
68-
return result
69-
else:
70-
result["status"] = "OK"
71-
result["structure_output"] = response.text
72-
except Exception as e:
70+
# TODO: Review timeout value
71+
response: Response = requests.post(
72+
url, json=payload, headers=headers, timeout=600
73+
)
74+
response.raise_for_status()
75+
result["status"] = "OK"
76+
result["structure_output"] = response.text
77+
except RequestException as e:
78+
result["error"] = f"Error occurred: {e}"
7379
self.tool.stream_log(
74-
f"Error while fetching response for prompt: {e}",
80+
f"Error while fetching response for prompt: {result['error']}",
7581
level=LogLevel.ERROR,
7682
)
77-
result["error"] = str(e)
78-
return result
7983
return result
8084

8185
@staticmethod

0 commit comments

Comments
 (0)