 from google.oauth2.service_account import Credentials
 from llama_index.core.llms import LLM
 from llama_index.llms.vertex import Vertex
+from vertexai.generative_models import Candidate, FinishReason, ResponseValidationError
 from vertexai.generative_models._generative_models import (
     HarmBlockThreshold,
     HarmCategory,
@@ -191,3 +192,62 @@ def test_connection(self) -> bool:
         except Exception as e:
             raise LLMError(f"Error while testing connection for VertexAI: {str(e)}")
         return test_result
+
+    @staticmethod
+    def parse_llm_err(e: ResponseValidationError) -> LLMError:
+        """Parse the error from Vertex AI.
+
+        Helps parse and raise errors from Vertex AI.
+        https://ai.google.dev/api/generate-content#generatecontentresponse
+
+        Args:
+            e (ResponseValidationError): Exception from Vertex AI
+
+        Returns:
+            LLMError: Error to be sent to the user
+        """
+        assert len(e.responses) == 1, (
+            "Expected e.responses to contain a single element "
+            "since it's a completion call and not chat."
+        )
+        resp = e.responses[0]
+        candidates: list["Candidate"] = resp.candidates
+        if not candidates:
+            return LLMError(str(resp.prompt_feedback))
+        reason_messages = {
+            FinishReason.MAX_TOKENS: (
+                "The maximum number of tokens for the LLM has been reached. Please "
+                "either tweak your prompts or try using another LLM."
+            ),
+            FinishReason.STOP: (
+                "The LLM stopped generating a response due to the natural stop "
+                "point of the model or a provided stop sequence."
+            ),
+            FinishReason.SAFETY: "The LLM response was flagged for safety reasons.",
+            FinishReason.RECITATION: (
+                "The LLM response was flagged for recitation reasons."
+            ),
+            FinishReason.BLOCKLIST: (
+                "The LLM response generation was stopped because it "
+                "contains forbidden terms."
+            ),
+            FinishReason.PROHIBITED_CONTENT: (
+                "The LLM response generation was stopped because it "
+                "potentially contains prohibited content."
+            ),
+            FinishReason.SPII: (
+                "The LLM response generation was stopped because it potentially "
+                "contains Sensitive Personally Identifiable Information."
+            ),
+        }
+
+        err_list = []
+        for candidate in candidates:
+            reason: FinishReason = candidate.finish_reason
+            if candidate.finish_message:
+                err_msg = candidate.finish_message
+            else:
+                err_msg = reason_messages.get(reason, str(candidate))
+            err_list.append(err_msg)
+        msg = "\n\nAnother error: \n".join(err_list)
+        return LLMError(msg)
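
A minimal usage sketch for the new helper: catch the validation error raised by a blocked Vertex AI completion and surface the parsed message instead of the raw exception. The `VertexAIAdapter` class name and the `safe_complete` wrapper are assumptions for illustration; only `parse_llm_err` itself comes from this diff.

```python
from vertexai.generative_models import ResponseValidationError


def safe_complete(llm, prompt: str) -> str:
    """Run a completion and convert blocked responses into a user-facing LLMError."""
    try:
        return llm.complete(prompt).text
    except ResponseValidationError as e:
        # parse_llm_err maps finish reasons (SAFETY, RECITATION, ...) to readable messages.
        # VertexAIAdapter is assumed here to be the class that defines parse_llm_err.
        raise VertexAIAdapter.parse_llm_err(e) from e
```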