1+ import logging
12from typing import Any , Optional
23
34import requests
4- from requests import RequestException , Response
5+ from requests import ConnectionError , RequestException , Response , Timeout
56
67from unstract .sdk .constants import LogLevel , PromptStudioKeys , ToolEnv
78from unstract .sdk .helper import SdkHelper
89from unstract .sdk .tool .base import BaseTool
910
11+ logger = logging .getLogger (__name__ )
12+
1013
1114class PromptTool :
1215 """Class to handle prompt service methods for Unstract Tools."""
@@ -25,9 +28,7 @@ def __init__(
2528
2629 """
2730 self .tool = tool
28- self .base_url = SdkHelper .get_platform_base_url (
29- prompt_host , prompt_port
30- )
31+ self .base_url = SdkHelper .get_platform_base_url (prompt_host , prompt_port )
3132 self .bearer_token = tool .get_env_or_die (ToolEnv .PLATFORM_API_KEY )
3233
3334 def answer_prompt (self , payload : dict [str , Any ]) -> dict [str , Any ]:
@@ -36,9 +37,7 @@ def answer_prompt(self, payload: dict[str, Any]) -> dict[str, Any]:
3637 def single_pass_extraction (self , payload : dict [str , Any ]) -> dict [str , Any ]:
3738 return self ._post_call ("single-pass-extraction" , payload )
3839
39- def _post_call (
40- self , url_path : str , payload : dict [str , Any ]
41- ) -> dict [str , Any ]:
40+ def _post_call (self , url_path : str , payload : dict [str , Any ]) -> dict [str , Any ]:
4241 """Invokes and communicates to prompt service to fetch response for the
4342 prompt.
4443
@@ -63,17 +62,24 @@ def _post_call(
6362 "structure_output" : "" ,
6463 }
6564 url : str = f"{ self .base_url } /{ url_path } "
66- headers : dict [str , str ] = {
67- "Authorization" : f"Bearer { self .bearer_token } "
68- }
65+ headers : dict [str , str ] = {"Authorization" : f"Bearer { self .bearer_token } " }
66+ response : Response = Response ()
6967 try :
70- # TODO: Review timeout value
71- response : Response = requests .post (
72- url , json = payload , headers = headers , timeout = 600
73- )
68+ response = requests .post (url , json = payload , headers = headers , timeout = 600 )
7469 response .raise_for_status ()
7570 result ["status" ] = "OK"
7671 result ["structure_output" ] = response .text
72+ except ConnectionError as connect_err :
73+ msg = "Unable to connect to prompt service. Please contact admin."
74+ self ._stringify_and_stream_err (connect_err , msg )
75+ result ["error" ] = msg
76+ except Timeout as time_out :
77+ msg = (
78+ "Request to run prompt has timed out. "
79+ "Probable causes might be connectivity issues in LLMs."
80+ )
81+ self ._stringify_and_stream_err (time_out , msg )
82+ result ["error" ] = msg
7783 except RequestException as e :
7884 # Extract error information from the response if available
7985 error_message = str (e )
@@ -91,6 +97,12 @@ def _post_call(
9197 )
9298 return result
9399
100+ def _stringify_and_stream_err (self , err : RequestException , msg : str ) -> None :
101+ error_message = str (err )
102+ trace = f"{ msg } : { error_message } "
103+ self .tool .stream_log (trace , level = LogLevel .ERROR )
104+ logger .error (trace )
105+
94106 @staticmethod
95107 def get_exported_tool (
96108 tool : BaseTool , prompt_registry_id : str
0 commit comments