diff --git a/src/sentry/llm/providers/vertex.py b/src/sentry/llm/providers/vertex.py
index 0cccb41af7fd06..73733c8a57aeea 100644
--- a/src/sentry/llm/providers/vertex.py
+++ b/src/sentry/llm/providers/vertex.py
@@ -2,7 +2,8 @@
 
 import google.auth
 import google.auth.transport.requests
-import requests
+from google import genai
+from google.genai.types import GenerateContentConfig, HttpOptions
 
 from sentry.llm.exceptions import VertexRequestFailed
 from sentry.llm.providers.base import LlmModelBase
@@ -30,35 +31,39 @@ def _complete_prompt(
         max_output_tokens: int,
     ) -> str | None:
+        model = usecase_config["options"]["model"]
         content = f"{prompt} {message}" if prompt else message
 
+        generate_config = GenerateContentConfig(
+            candidate_count=self.candidate_count,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_p=self.top_p,
+        )
-        payload = {
-            "instances": [{"content": content}],
-            "parameters": {
-                "candidateCount": self.candidate_count,
-                "maxOutputTokens": max_output_tokens,
-                "temperature": temperature,
-                "topP": self.top_p,
-            },
-        }
-
-        headers = {
-            "Authorization": f"Bearer {self._get_access_token()}",
-            "Content-Type": "application/json",
-        }
-        vertex_url = self.provider_config["options"]["url"]
-        vertex_url += usecase_config["options"]["model"] + ":predict"
-
-        response = requests.post(vertex_url, headers=headers, json=payload)
+        client = self._create_genai_client()
+        response = client.models.generate_content(
+            model=model,
+            contents=content,
+            config=generate_config,
+        )
 
         if response.status_code != 200:
             logger.error(
-                "Request failed with status code and response text.",
-                extra={"status_code": response.status_code, "response_text": response.text},
+                "Vertex request failed.",
+                extra={"status_code": response.status_code},
             )
-            raise VertexRequestFailed(f"Response {response.status_code}: {response.text}")
+            raise VertexRequestFailed(f"Response {response.status_code}")
+
+        return response.text
 
-        return response.json()["predictions"][0]["content"]
+    # Separate method to allow mocking
+    def _create_genai_client(self):
+        return genai.Client(
+            vertexai=True,
+            project=self.provider_config["options"]["gcp_project"],
+            location=self.provider_config["options"]["gcp_location"],
+            http_options=HttpOptions(api_version="v1"),
+        )
 
     def _get_access_token(self) -> str:
         # https://stackoverflow.com/questions/53472429/how-to-get-a-gcp-bearer-token-programmatically-with-python
 
diff --git a/tests/sentry/llm/test_vertex.py b/tests/sentry/llm/test_vertex.py
index 09f43cf85e15ec..133ef06a980e00 100644
--- a/tests/sentry/llm/test_vertex.py
+++ b/tests/sentry/llm/test_vertex.py
@@ -1,30 +1,53 @@
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
-from sentry.llm.usecases import LLMUseCase, complete_prompt
+import pytest
 
+from sentry.llm.exceptions import VertexRequestFailed
+from sentry.llm.usecases import LLMUseCase, complete_prompt, llm_provider_backends
 
-def test_complete_prompt(set_sentry_option):
+
+@pytest.fixture
+def mock_options(set_sentry_option):
     with (
         set_sentry_option(
             "llm.provider.options",
-            {"vertex": {"models": ["vertex-1.0"], "options": {"url": "fake_url"}}},
+            {
+                "vertex": {
+                    "models": ["vertex-1.0"],
+                    "options": {"gcp_project": "my-gcp-project", "gcp_location": "us-central1"},
+                }
+            },
         ),
         set_sentry_option(
             "llm.usecases.options",
             {"example": {"provider": "vertex", "options": {"model": "vertex-1.0"}}},
         ),
-        patch(
-            "sentry.llm.providers.vertex.VertexProvider._get_access_token",
-            return_value="fake_token",
-        ),
-        patch(
-            "requests.post",
-            return_value=type(
-                "obj",
-                (object,),
-                {"status_code": 200, "json": lambda x: {"predictions": [{"content": ""}]}},
-            )(),
-        ),
+    ):
+        yield
+
+
+class MockGenaiClient:
+    def __init__(self, mock_generate_content):
+        self.models = type(
+            "obj",
+            (object,),
+            {"generate_content": mock_generate_content},
+        )()
+
+
+def test_complete_prompt(mock_options):
+    llm_provider_backends.clear()
+    mock_generate_content = Mock(
+        return_value=type(
+            "obj",
+            (object,),
+            {"status_code": 200, "text": "hello world"},
+        )()
+    )
+
+    with patch(
+        "sentry.llm.providers.vertex.VertexProvider._create_genai_client",
+        return_value=MockGenaiClient(mock_generate_content),
     ):
         res = complete_prompt(
             usecase=LLMUseCase.EXAMPLE,
@@ -33,4 +56,30 @@ def test_complete_prompt(set_sentry_option):
             temperature=0.0,
             max_output_tokens=1024,
         )
-        assert res == ""
+
+    assert res == "hello world"
+    assert mock_generate_content.call_count == 1
+    assert mock_generate_content.call_args[1]["model"] == "vertex-1.0"
+
+
+def test_complete_prompt_error(mock_options):
+    llm_provider_backends.clear()
+    mock_generate_content = Mock(
+        return_value=type(
+            "obj",
+            (object,),
+            {"status_code": 400},
+        )()
+    )
+
+    with patch(
+        "sentry.llm.providers.vertex.VertexProvider._create_genai_client",
+        return_value=MockGenaiClient(mock_generate_content),
+    ):
+        with pytest.raises(VertexRequestFailed):
+            complete_prompt(
+                usecase=LLMUseCase.EXAMPLE,
+                message="message here",
+                temperature=0.0,
+                max_output_tokens=1024,
+            )