 import logging
 import re
-import time
 from typing import Any, Optional
 
 from llama_index.core.llms import LLM, CompletionResponse
+from openai import APIError as OpenAIAPIError
+from openai import RateLimitError as OpenAIRateLimitError
 from unstract.adapters.constants import Common
 from unstract.adapters.llm import adapters
 from unstract.adapters.llm.llm_adapter import LLMAdapter
 
 from unstract.sdk.adapters import ToolAdapter
 from unstract.sdk.constants import LogLevel
-from unstract.sdk.exceptions import SdkError
+from unstract.sdk.exceptions import RateLimitError, SdkError, ToolLLMError
 from unstract.sdk.tool.base import BaseTool
-from unstract.sdk.utils.callback_manager import (
-    CallbackManager as UNCallbackManager,
-)
+from unstract.sdk.utils.callback_manager import CallbackManager as UNCallbackManager
 
 logger = logging.getLogger(__name__)
 
 
 class ToolLLM:
     """Class to handle LLMs for Unstract Tools."""
 
-    json_regex = re.compile(r"\{(?:.|\n)*\}")
+    json_regex = re.compile(r"\[(?:.|\n)*\]|\{(?:.|\n)*\}")
 
     def __init__(self, tool: BaseTool):
         """ToolLLM constructor.
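For reference, the widened json_regex above now also captures a top-level JSON array, not just an object. A minimal standalone sketch of what the pattern extracts; the sample completions below are invented for illustration and are not part of the change:

import re

# Same pattern as the new class attribute: match a JSON array or a JSON object.
json_regex = re.compile(r"\[(?:.|\n)*\]|\{(?:.|\n)*\}")

samples = [
    'Sure, here you go: {"status": "ok"}',          # object wrapped in prose
    'Items found:\n[{"id": 1}, {"id": 2}]\nDone.',  # array wrapped in prose
]
for text in samples:
    match = json_regex.search(text)
    if match:
        print(match.group(0))  # only the JSON payload survives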
@@ -62,21 +61,21 @@ def run_completion(
             "run_id",
         ]:
             new_kwargs.pop(key, None)
-        for i in range(retries):
-            try:
-                response: CompletionResponse = llm.complete(
-                    prompt, **new_kwargs
-                )
-                match = cls.json_regex.search(response.text)
-                if match:
-                    response.text = match.group(0)
-                return {"response": response}
-
-            except Exception as e:
-                if i == retries - 1:
-                    raise e
-                time.sleep(5)
-        return None
+
+        try:
+            response: CompletionResponse = llm.complete(prompt, **new_kwargs)
+            match = cls.json_regex.search(response.text)
+            if match:
+                response.text = match.group(0)
+            return {"response": response}
+        # TODO: Handle for all LLM providers
+        except OpenAIAPIError as e:
+            msg = e.message
+            if hasattr(e, "body") and "message" in e.body:
+                msg = e.body["message"]
+            if isinstance(e, OpenAIRateLimitError):
+                raise RateLimitError(msg)
+            raise ToolLLMError(msg) from e
 
     def get_llm(self, adapter_instance_id: str) -> LLM:
         """Returns the LLM object for the tool.
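With the retry loop removed, callers now see provider failures as typed SDK exceptions instead of a None return. A hedged sketch of the caller side, assuming run_completion remains a classmethod accepting llm and prompt keyword arguments; its full signature is not visible in this hunk, and the import path for ToolLLM is assumed:

from llama_index.core.llms import LLM

from unstract.sdk.exceptions import RateLimitError, ToolLLMError
from unstract.sdk.llm import ToolLLM  # module path assumed


def complete_with_handling(llm: LLM, prompt: str) -> str:
    try:
        # Argument names are assumptions; adjust to the real signature.
        result = ToolLLM.run_completion(llm=llm, prompt=prompt)
        return result["response"].text
    except RateLimitError:
        # Provider reported a rate limit (mapped only for OpenAI so far, per the TODO);
        # the SDK no longer sleeps and retries, so any backoff belongs here.
        raise
    except ToolLLMError:
        # Any other provider API error, surfaced with the provider's own message.
        raise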
@@ -91,13 +90,11 @@ def get_llm(self, adapter_instance_id: str) -> LLM:
         )
         llm_adapter_id = llm_config_data.get(Common.ADAPTER_ID)
         if llm_adapter_id not in self.llm_adapters:
-            raise SdkError(
-                f"LLM adapter not supported : " f"{llm_adapter_id}"
-            )
+            raise SdkError(f"LLM adapter not supported : " f"{llm_adapter_id}")
 
-        llm_adapter = self.llm_adapters[llm_adapter_id][
-            Common.METADATA
-        ][Common.ADAPTER]
+        llm_adapter = self.llm_adapters[llm_adapter_id][Common.METADATA][
+            Common.ADAPTER
+        ]
         llm_metadata = llm_config_data.get(Common.ADAPTER_METADATA)
         llm_adapter_class: LLMAdapter = llm_adapter(llm_metadata)
         llm_instance: LLM = llm_adapter_class.get_llm_instance()
@@ -106,7 +103,7 @@ def get_llm(self, adapter_instance_id: str) -> LLM:
             self.tool.stream_log(
                 log=f"Unable to get llm instance: {e}", level=LogLevel.ERROR
             )
-            raise SdkError(f"Error getting llm instance: {e}")
+            raise ToolLLMError(f"Error getting llm instance: {e}")
 
     def get_max_tokens(self, reserved_for_output: int = 0) -> int:
         """Returns the maximum number of tokens that can be used for the LLM.