|
17 | 17 | from canopy.models.data_models import Messages, Query |
18 | 18 |
|
19 | 19 |
|
| 20 | +def _format_openai_error(e): |
| 21 | + try: |
| 22 | + return e.response.json()['error']['message'] |
| 23 | + except Exception: |
| 24 | + return str(e) |
| 25 | + |
| 26 | + |
20 | 27 | class OpenAILLM(BaseLLM): |
21 | 28 | """ |
22 | 29 | OpenAI LLM wrapper built on top of the OpenAI Python client. |
@@ -48,9 +55,17 @@ def __init__(self, |
48 | 55 | These params can be overridden by passing a `model_params` argument to the `chat_completion` or `enforced_function_call` methods. |
49 | 56 | """ # noqa: E501 |
50 | 57 | super().__init__(model_name) |
51 | | - self._client = openai.OpenAI(api_key=api_key, |
52 | | - organization=organization, |
53 | | - base_url=base_url) |
| 58 | + try: |
| 59 | + self._client = openai.OpenAI(api_key=api_key, |
| 60 | + organization=organization, |
| 61 | + base_url=base_url) |
| 62 | + except openai.OpenAIError as e: |
| 63 | + raise RuntimeError( |
| 64 | + "Failed to connect to OpenAI, please make sure that the OPENAI_API_KEY " |
| 65 | + "environment variable is set correctly.\n" |
| 66 | + f"Error: {_format_openai_error(e)}" |
| 67 | + ) |
| 68 | + |
54 | 69 | self.default_model_params = kwargs |
55 | 70 |
|
56 | 71 | @property |
@@ -96,11 +111,19 @@ def chat_completion(self, |
96 | 111 | ) |
97 | 112 |
|
98 | 113 | messages = [m.dict() for m in messages] |
99 | | - response = self._client.chat.completions.create(model=self.model_name, |
100 | | - messages=messages, |
101 | | - stream=stream, |
102 | | - max_tokens=max_tokens, |
103 | | - **model_params_dict) |
| 114 | + try: |
| 115 | + response = self._client.chat.completions.create(model=self.model_name, |
| 116 | + messages=messages, |
| 117 | + stream=stream, |
| 118 | + max_tokens=max_tokens, |
| 119 | + **model_params_dict) |
| 120 | + except openai.OpenAIError as e: |
| 121 | + provider_name = self.__class__.__name__.replace("LLM", "") |
| 122 | + raise RuntimeError( |
| 123 | + f"Failed to use {provider_name}'s {self.model_name} model for chat " |
| 124 | + f"completion.\n" |
| 125 | + f"Error: {_format_openai_error(e)}" |
| 126 | + ) |
104 | 127 |
|
105 | 128 | def streaming_iterator(response): |
106 | 129 | for chunk in response: |
@@ -175,15 +198,23 @@ def enforced_function_call(self, |
175 | 198 | function_dict = cast(ChatCompletionToolParam, |
176 | 199 | {"type": "function", "function": function.dict()}) |
177 | 200 |
|
178 | | - chat_completion = self._client.chat.completions.create( |
179 | | - messages=[m.dict() for m in messages], |
180 | | - model=self.model_name, |
181 | | - tools=[function_dict], |
182 | | - tool_choice={"type": "function", |
183 | | - "function": {"name": function.name}}, |
184 | | - max_tokens=max_tokens, |
185 | | - **model_params_dict |
186 | | - ) |
| 201 | + try: |
| 202 | + chat_completion = self._client.chat.completions.create( |
| 203 | + messages=[m.dict() for m in messages], |
| 204 | + model=self.model_name, |
| 205 | + tools=[function_dict], |
| 206 | + tool_choice={"type": "function", |
| 207 | + "function": {"name": function.name}}, |
| 208 | + max_tokens=max_tokens, |
| 209 | + **model_params_dict |
| 210 | + ) |
| 211 | + except openai.OpenAIError as e: |
| 212 | + provider_name = self.__class__.__name__.replace("LLM", "") |
| 213 | + raise RuntimeError( |
| 214 | + f"Failed to use {provider_name}'s {self.model_name} model for " |
| 215 | + f"chat completion with enforced function calling.\n" |
| 216 | + f"Error: {_format_openai_error(e)}" |
| 217 | + ) |
187 | 218 |
|
188 | 219 | result = chat_completion.choices[0].message.tool_calls[0].function.arguments |
189 | 220 | arguments = json.loads(result) |
|
0 commit comments