Skip to content

Commit ac751c4

Browse files
[FIX] Exception enhancements for Prompt Service (#52)
* Exception handling for Prompt Service * Optimizing error messages * Update src/unstract/sdk/prompt.py * Error message enhancements. * Removing unused exceptions --------- Signed-off-by: harini-venkataraman <[email protected]>
1 parent 0fabae8 commit ac751c4

File tree

2 files changed

+27
-15
lines changed

2 files changed

+27
-15
lines changed

src/unstract/sdk/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.27.0"
1+
__version__ = "0.27.1"
22

33

44
def get_sdk_version():

src/unstract/sdk/prompt.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
import logging
12
from typing import Any, Optional
23

34
import requests
4-
from requests import RequestException, Response
5+
from requests import ConnectionError, RequestException, Response, Timeout
56

67
from unstract.sdk.constants import LogLevel, PromptStudioKeys, ToolEnv
78
from unstract.sdk.helper import SdkHelper
89
from unstract.sdk.tool.base import BaseTool
910

11+
logger = logging.getLogger(__name__)
12+
1013

1114
class PromptTool:
1215
"""Class to handle prompt service methods for Unstract Tools."""
@@ -25,9 +28,7 @@ def __init__(
2528
2629
"""
2730
self.tool = tool
28-
self.base_url = SdkHelper.get_platform_base_url(
29-
prompt_host, prompt_port
30-
)
31+
self.base_url = SdkHelper.get_platform_base_url(prompt_host, prompt_port)
3132
self.bearer_token = tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
3233

3334
def answer_prompt(self, payload: dict[str, Any]) -> dict[str, Any]:
@@ -36,9 +37,7 @@ def answer_prompt(self, payload: dict[str, Any]) -> dict[str, Any]:
3637
def single_pass_extraction(self, payload: dict[str, Any]) -> dict[str, Any]:
3738
return self._post_call("single-pass-extraction", payload)
3839

39-
def _post_call(
40-
self, url_path: str, payload: dict[str, Any]
41-
) -> dict[str, Any]:
40+
def _post_call(self, url_path: str, payload: dict[str, Any]) -> dict[str, Any]:
4241
"""Invokes and communicates to prompt service to fetch response for the
4342
prompt.
4443
@@ -63,17 +62,24 @@ def _post_call(
6362
"structure_output": "",
6463
}
6564
url: str = f"{self.base_url}/{url_path}"
66-
headers: dict[str, str] = {
67-
"Authorization": f"Bearer {self.bearer_token}"
68-
}
65+
headers: dict[str, str] = {"Authorization": f"Bearer {self.bearer_token}"}
66+
response: Response = Response()
6967
try:
70-
# TODO: Review timeout value
71-
response: Response = requests.post(
72-
url, json=payload, headers=headers, timeout=600
73-
)
68+
response = requests.post(url, json=payload, headers=headers, timeout=600)
7469
response.raise_for_status()
7570
result["status"] = "OK"
7671
result["structure_output"] = response.text
72+
except ConnectionError as connect_err:
73+
msg = "Unable to connect to prompt service. Please contact admin."
74+
self._stringify_and_stream_err(connect_err, msg)
75+
result["error"] = msg
76+
except Timeout as time_out:
77+
msg = (
78+
"Request to run prompt has timed out. "
79+
"Probable causes might be connectivity issues in LLMs."
80+
)
81+
self._stringify_and_stream_err(time_out, msg)
82+
result["error"] = msg
7783
except RequestException as e:
7884
# Extract error information from the response if available
7985
error_message = str(e)
@@ -91,6 +97,12 @@ def _post_call(
9197
)
9298
return result
9399

100+
def _stringify_and_stream_err(self, err: RequestException, msg: str) -> None:
101+
error_message = str(err)
102+
trace = f"{msg}: {error_message}"
103+
self.tool.stream_log(trace, level=LogLevel.ERROR)
104+
logger.error(trace)
105+
94106
@staticmethod
95107
def get_exported_tool(
96108
tool: BaseTool, prompt_registry_id: str

0 commit comments

Comments
 (0)