
Commit 4acfa68

fix: Error handling improvement for gemini (#84)
* Error handling fix for gemini, bumped SDK to 0.41.1

* Update src/unstract/sdk/adapters/llm/vertex_ai/src/vertex_ai.py

Co-authored-by: Deepak K <[email protected]>
Signed-off-by: Chandrasekharan M <[email protected]>

* Minor fixes on gemini LLM error handling

---------

Signed-off-by: Chandrasekharan M <[email protected]>
Co-authored-by: Deepak K <[email protected]>
1 parent ceec5d4 commit 4acfa68

File tree

11 files changed: +166 −42 lines changed

src/unstract/sdk/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-__version__ = "0.41.0"
+__version__ = "0.41.1"
 
 
 def get_sdk_version():

src/unstract/sdk/adapter.py

Lines changed: 6 additions & 7 deletions
@@ -58,6 +58,7 @@ def get_adapter_configuration(
         if response.status_code == 200:
             adapter_data: dict[str, Any] = response.json()
 
+            # TODO: Print config after redacting sensitive information
             self.tool.stream_log(
                 "Successfully retrieved adapter config "
                 f"for adapter: {adapter_instance_id}"
@@ -104,18 +105,16 @@ def get_adapter_config(
             Any: engine
         """
         # Check if the adapter ID matches any public adapter keys
-        if SdkHelper.is_public_adapter(
-            adapter_id=adapter_instance_id
-        ):
-            adapter_metadata_config = tool.get_env_or_die(
-                adapter_instance_id
-            )
+        if SdkHelper.is_public_adapter(adapter_id=adapter_instance_id):
+            adapter_metadata_config = tool.get_env_or_die(adapter_instance_id)
             adapter_metadata = json.loads(adapter_metadata_config)
             return adapter_metadata
         platform_host = tool.get_env_or_die(ToolEnv.PLATFORM_HOST)
         platform_port = tool.get_env_or_die(ToolEnv.PLATFORM_PORT)
 
-        tool.stream_log("Connecting to DB and getting table metadata")
+        tool.stream_log(
+            f"Connecting to DB and getting table metadata for {adapter_instance_id}"
+        )
         tool_adapter = ToolAdapter(
             tool=tool,
             platform_host=platform_host,

src/unstract/sdk/adapters/exceptions.py

Lines changed: 3 additions & 9 deletions
@@ -1,14 +1,8 @@
-from unstract.sdk.adapters.constants import Common
+from unstract.sdk.exceptions import SdkError
 
 
-class AdapterError(Exception):
-    def __init__(self, message: str = Common.DEFAULT_ERR_MESSAGE):
-        super().__init__(message)
-        # Make it user friendly wherever possible
-        self.message = message
-
-    def __str__(self) -> str:
-        return self.message
+class AdapterError(SdkError):
+    pass
 
 
 class LLMError(AdapterError):
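
Note: with AdapterError rebased onto SdkError, a single handler at the tool boundary now catches adapter failures together with every other SDK error. A minimal sketch of the resulting hierarchy (not part of the commit; only the imported names come from this diff):

from unstract.sdk.adapters.exceptions import AdapterError, LLMError
from unstract.sdk.exceptions import SdkError

# LLMError -> AdapterError -> SdkError, so one except clause covers all.
try:
    raise LLMError("provider returned an unusable response")
except SdkError as e:
    print(f"caught via the common base class: {e}")
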
src/unstract/sdk/adapters/llm/exceptions.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+from openai import APIError as OpenAIAPIError
+from vertexai.generative_models import ResponseValidationError
+
+from unstract.sdk.adapters.exceptions import LLMError
+from unstract.sdk.adapters.llm.open_ai.src.open_ai import OpenAILLM
+from unstract.sdk.adapters.llm.vertex_ai.src.vertex_ai import VertexAILLM
+
+
+def parse_llm_err(e: Exception) -> LLMError:
+    """Parses the exception from LLM provider.
+
+    Helps parse the LLM error and wraps it with our
+    custom exception object to contain a user friendly message.
+
+    Args:
+        e (Exception): Error from LLM provider
+
+    Returns:
+        LLMError: Unstract's LLMError object
+    """
+    if isinstance(e, ResponseValidationError):
+        return VertexAILLM.parse_llm_err(e)
+    elif isinstance(e, OpenAIAPIError):
+        return OpenAILLM.parse_llm_err(e)
+    return LLMError(str(e))
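
Note: a usage sketch for the new dispatcher (not part of the commit; the llm.complete() call stands in for any provider invocation). This mirrors how llm.py uses it further down in this commit:

from unstract.sdk.adapters.llm.exceptions import parse_llm_err

def complete_safely(llm, prompt: str):
    """Run a completion, surfacing provider failures as LLMError."""
    try:
        return llm.complete(prompt)
    except Exception as e:
        # parse_llm_err dispatches on the exception type and always
        # returns an LLMError (or a subclass such as RateLimitError).
        raise parse_llm_err(e) from e
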

src/unstract/sdk/adapters/llm/llm_adapter.py

Lines changed: 15 additions & 0 deletions
@@ -6,6 +6,7 @@
 
 from unstract.sdk.adapters.base import Adapter
 from unstract.sdk.adapters.enums import AdapterTypes
+from unstract.sdk.adapters.exceptions import LLMError
 
 logger = logging.getLogger(__name__)
 
@@ -44,6 +45,20 @@ def get_json_schema() -> str:
     def get_adapter_type() -> AdapterTypes:
         return AdapterTypes.LLM
 
+    @staticmethod
+    def parse_llm_err(e: Exception) -> LLMError:
+        """Parse the error from an LLM provider.
+
+        Helps parse errors from a provider and wraps with custom exception.
+
+        Args:
+            e (Exception): Exception from LLM provider
+
+        Returns:
+            LLMError: Error to be sent to the user
+        """
+        return LLMError(str(e))
+
     def get_llm_instance(self) -> LLM:
         """Instantiate the llama index LLM class.

src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py

Lines changed: 24 additions & 0 deletions
@@ -3,11 +3,14 @@
 
 from llama_index.core.llms import LLM
 from llama_index.llms.openai import OpenAI
+from openai import APIError as OpenAIAPIError
+from openai import RateLimitError as OpenAIRateLimitError
 
 from unstract.sdk.adapters.exceptions import AdapterError
 from unstract.sdk.adapters.llm.constants import LLMKeys
 from unstract.sdk.adapters.llm.helper import LLMHelper
 from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter
+from unstract.sdk.exceptions import LLMError, RateLimitError
 
 
 class Constants:
@@ -76,3 +79,24 @@ def test_connection(self) -> bool:
         llm = self.get_llm_instance()
         test_result: bool = LLMHelper.test_llm_instance(llm=llm)
         return test_result
+
+    @staticmethod
+    def parse_llm_err(e: OpenAIAPIError) -> LLMError:
+        """Parse the error from Open AI.
+
+        Helps parse errors from Open AI and wraps with custom exception.
+
+        Args:
+            e (OpenAIAPIError): Exception from Open AI
+
+        Returns:
+            LLMError: Error to be sent to the user
+        """
+        msg = "OpenAI error: "
+        if hasattr(e, "body") and "message" in e.body:
+            msg += e.body["message"]
+        else:
+            msg += e.message
+        if isinstance(e, OpenAIRateLimitError):
+            return RateLimitError(msg)
+        return LLMError(msg)
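
Note: one way a caller could exploit the RateLimitError distinction is to retry only throttled requests. A sketch under the assumption that retry policy lives with the caller (the retry count and backoff are illustrative, not part of the SDK):

import time

from openai import APIError as OpenAIAPIError

from unstract.sdk.adapters.llm.open_ai.src.open_ai import OpenAILLM
from unstract.sdk.exceptions import RateLimitError

def call_with_retry(call, retries: int = 3, backoff: float = 2.0):
    """Retry a callable on rate limits; re-raise other LLM errors."""
    for attempt in range(retries):
        try:
            return call()
        except OpenAIAPIError as e:
            err = OpenAILLM.parse_llm_err(e)
            if isinstance(err, RateLimitError) and attempt < retries - 1:
                time.sleep(backoff**attempt)  # simple exponential backoff
                continue
            raise err from e
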

src/unstract/sdk/adapters/llm/vertex_ai/src/vertex_ai.py

Lines changed: 60 additions & 0 deletions
@@ -7,6 +7,7 @@
 from google.oauth2.service_account import Credentials
 from llama_index.core.llms import LLM
 from llama_index.llms.vertex import Vertex
+from vertexai.generative_models import Candidate, FinishReason, ResponseValidationError
 from vertexai.generative_models._generative_models import (
     HarmBlockThreshold,
     HarmCategory,
@@ -191,3 +192,62 @@ def test_connection(self) -> bool:
         except Exception as e:
             raise LLMError(f"Error while testing connection for VertexAI: {str(e)}")
         return test_result
+
+    @staticmethod
+    def parse_llm_err(e: ResponseValidationError) -> LLMError:
+        """Parse the error from Vertex AI.
+
+        Helps parse and raise errors from Vertex AI.
+        https://ai.google.dev/api/generate-content#generatecontentresponse
+
+        Args:
+            e (ResponseValidationError): Exception from Vertex AI
+
+        Returns:
+            LLMError: Error to be sent to the user
+        """
+        assert len(e.responses) == 1, (
+            "Expected e.responses to contain a single element "
+            "since its a completion call and not chat."
+        )
+        resp = e.responses[0]
+        candidates: list["Candidate"] = resp.candidates
+        if not candidates:
+            msg = str(resp.prompt_feedback)
+        reason_messages = {
+            FinishReason.MAX_TOKENS: (
+                "The maximum number of tokens for the LLM has been reached. Please "
+                "either tweak your prompts or try using another LLM."
+            ),
+            FinishReason.STOP: (
+                "The LLM stopped generating a response due to the natural stop "
+                "point of the model or a provided stop sequence."
+            ),
+            FinishReason.SAFETY: "The LLM response was flagged for safety reasons.",
+            FinishReason.RECITATION: (
+                "The LLM response was flagged for recitation reasons."
+            ),
+            FinishReason.BLOCKLIST: (
+                "The LLM response generation was stopped because it "
+                "contains forbidden terms."
+            ),
+            FinishReason.PROHIBITED_CONTENT: (
+                "The LLM response generation was stopped because it "
+                "potentially contains prohibited content."
+            ),
+            FinishReason.SPII: (
+                "The LLM response generation was stopped because it potentially "
+                "contains Sensitive Personally Identifiable Information."
+            ),
+        }
+
+        err_list = []
+        for candidate in candidates:
+            reason: FinishReason = candidate.finish_reason
+            if candidate.finish_message:
+                err_msg = candidate.finish_message
+            else:
+                err_msg = reason_messages.get(reason, str(candidate))
+            err_list.append(err_msg)
+        msg = "\n\nAnother error: \n".join(err_list)
+        return LLMError(msg)
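
Note: the core of this parser is a precedence rule: a provider-supplied finish_message wins, then the FinishReason lookup, then a raw dump of the candidate. A reduced sketch of that precedence with plain values in place of real Candidate objects (not part of the commit):

from vertexai.generative_models import FinishReason

# Trimmed copy of the mapping above; the full table lives in VertexAILLM.
REASON_MESSAGES = {
    FinishReason.MAX_TOKENS: "The maximum number of tokens for the LLM has been reached.",
    FinishReason.SAFETY: "The LLM response was flagged for safety reasons.",
}

def message_for(reason, finish_message: str = "", fallback: str = "") -> str:
    if finish_message:  # provider-supplied detail wins
        return finish_message
    return REASON_MESSAGES.get(reason, fallback)  # else map, else raw dump

print(message_for(FinishReason.SAFETY))
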

src/unstract/sdk/llm.py

Lines changed: 4 additions & 13 deletions
@@ -11,6 +11,7 @@
 from unstract.sdk.adapter import ToolAdapter
 from unstract.sdk.adapters.constants import Common
 from unstract.sdk.adapters.llm import adapters
+from unstract.sdk.adapters.llm.exceptions import parse_llm_err
 from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter
 from unstract.sdk.constants import LogLevel, ToolEnv
 from unstract.sdk.exceptions import LLMError, RateLimitError, SdkError
@@ -56,9 +57,7 @@ def _initialise(self):
         self._llm_instance = self._get_llm(self._adapter_instance_id)
         self._usage_kwargs["adapter_instance_id"] = self._adapter_instance_id
 
-        if not SdkHelper.is_public_adapter(
-            adapter_id=self._adapter_instance_id
-        ):
+        if not SdkHelper.is_public_adapter(adapter_id=self._adapter_instance_id):
             platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
             CallbackManager.set_callback(
                 platform_api_key=platform_api_key,
@@ -78,16 +77,8 @@ def complete(
             if match:
                 response.text = match.group(0)
             return {LLM.RESPONSE: response}
-        # TODO: Handle for all LLM providers
-        except OpenAIAPIError as e:
-            msg = "OpenAI error: "
-            if hasattr(e, "body") and "message" in e.body:
-                msg += e.body["message"]
-            else:
-                msg += e.message
-            if isinstance(e, OpenAIRateLimitError):
-                raise RateLimitError(msg)
-            raise LLMError(msg) from e
+        except Exception as e:
+            raise parse_llm_err(e) from e
 
     def _get_llm(self, adapter_instance_id: str) -> LlamaIndexLLM:
         """Returns the LLM object for the tool.

src/unstract/sdk/tool/executor.py

Lines changed: 10 additions & 5 deletions
@@ -49,17 +49,22 @@ def execute_run(self, args: argparse.Namespace) -> None:
             self.tool.stream_error_and_exit("--settings are required for RUN command")
         settings: dict[str, Any] = loads(args.settings)
 
-        self._setup_for_run()
+        self.tool.stream_log(
+            f"Running tool with "
+            f"Workflow ID: {self.tool.workflow_id}, "
+            f"Execution ID: {self.tool.execution_id}, "
+            f"SDK Version: {get_sdk_version()}"
+        )
 
+        self._setup_for_run()
         validator = ToolValidator(self.tool)
         settings = validator.validate_pre_execution(settings=settings)
 
         self.tool.stream_log(
-            f"Running tool for "
-            f"Workflow ID: {self.tool.workflow_id}, "
-            f"Execution ID: {self.tool.execution_id}, "
-            f"SDK Version: {get_sdk_version()}, "
+            f"Executing for file: {self.tool.get_exec_metadata['source_name']}, "
+            f"with tool settings: {settings}"
         )
+
         try:
             self.tool.run(
                 settings=settings,

src/unstract/sdk/tool/stream.py

Lines changed: 4 additions & 7 deletions
@@ -6,6 +6,7 @@
 from deprecated import deprecated
 
 from unstract.sdk.constants import Command, LogLevel, LogStage, ToolEnv
+from unstract.sdk.utils import ToolUtils
 
 
 class StreamMixin:
@@ -26,7 +27,7 @@ def __init__(self, log_level: LogLevel = LogLevel.INFO, **kwargs) -> None:
 
         """
         self.log_level = log_level
-        self._exec_by_tool = bool(
+        self._exec_by_tool = ToolUtils.str_to_bool(
             os.environ.get(ToolEnv.EXECUTION_BY_TOOL, "False")
         )
         super().__init__(**kwargs)
@@ -78,9 +79,7 @@ def stream_error_and_exit(self, message: str) -> None:
         if self._exec_by_tool:
             exit(1)
         else:
-            raise RuntimeError(
-                "RuntimeError from SDK, check the above log for details"
-            )
+            raise RuntimeError("RuntimeError from SDK, check the above log for details")
 
     def get_env_or_die(self, env_key: str) -> str:
         """Returns the value of an env variable.
@@ -232,9 +231,7 @@ def stream_single_step_message(message: str, **kwargs: Any) -> None:
         print(json.dumps(record))
 
     @staticmethod
-    @deprecated(
-        version="0.4.4", reason="Use `BaseTool.write_to_result()` instead"
-    )
+    @deprecated(version="0.4.4", reason="Use `BaseTool.write_to_result()` instead")
     def stream_result(result: dict[Any, Any], **kwargs: Any) -> None:
         """Streams the result of the tool using the Unstract protocol RESULT to
         stdout.
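
Note: the str_to_bool switch fixes a classic Python trap: bool() on any non-empty string is True, so the old code treated the env default "False" as truthy. A minimal sketch of the difference (this str_to_bool only assumes the usual truthy-string semantics; ToolUtils' implementation may differ in detail):

# Old behaviour: any non-empty string is truthy, including "False".
assert bool("False") is True

def str_to_bool(value: str) -> bool:
    """Parse common string spellings of a boolean (assumed semantics)."""
    return value.strip().lower() in ("1", "true", "yes", "on")

# New behaviour: the default "False" now reads as False.
assert str_to_bool("False") is False
assert str_to_bool("true") is True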
