51 changes: 28 additions & 23 deletions src/sentry/llm/providers/vertex.py
@@ -2,7 +2,8 @@
 
 import google.auth
 import google.auth.transport.requests
-import requests
+from google import genai
+from google.genai.types import GenerateContentConfig, HttpOptions
 
 from sentry.llm.exceptions import VertexRequestFailed
 from sentry.llm.providers.base import LlmModelBase
@@ -30,35 +31,39 @@ def _complete_prompt(
         max_output_tokens: int,
     ) -> str | None:
 
+        model = usecase_config["options"]["model"]
         content = f"{prompt} {message}" if prompt else message
+        generate_config = GenerateContentConfig(
+            candidate_count=self.candidate_count,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_p=self.top_p,
+        )
 
-        payload = {
-            "instances": [{"content": content}],
-            "parameters": {
-                "candidateCount": self.candidate_count,
-                "maxOutputTokens": max_output_tokens,
-                "temperature": temperature,
-                "topP": self.top_p,
-            },
-        }
-
-        headers = {
-            "Authorization": f"Bearer {self._get_access_token()}",
-            "Content-Type": "application/json",
-        }
-        vertex_url = self.provider_config["options"]["url"]
-        vertex_url += usecase_config["options"]["model"] + ":predict"
-
-        response = requests.post(vertex_url, headers=headers, json=payload)
+        client = self._create_genai_client()
+        response = client.models.generate_content(
+            model=model,
+            contents=content,
+            config=generate_config,
+        )
 
         if response.status_code != 200:
             logger.error(
-                "Request failed with status code and response text.",
-                extra={"status_code": response.status_code, "response_text": response.text},
+                "Vertex request failed.",
+                extra={"status_code": response.status_code},
             )
-            raise VertexRequestFailed(f"Response {response.status_code}: {response.text}")
+            raise VertexRequestFailed(f"Response {response.status_code}")
+
+        return response.text
 
-        return response.json()["predictions"][0]["content"]
+    # Separate method to allow mocking
+    def _create_genai_client(self):
+        return genai.Client(
+            vertexai=True,
+            project=self.provider_config["options"]["gcp_project"],
+            location=self.provider_config["options"]["gcp_location"],
+            http_options=HttpOptions(api_version="v1"),
+        )
 
     def _get_access_token(self) -> str:
        # https://stackoverflow.com/questions/53472429/how-to-get-a-gcp-bearer-token-programmatically-with-python
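For reviewers unfamiliar with the google-genai SDK, the sketch below shows the call path this diff adopts, in isolation from Sentry's provider plumbing. It is a minimal illustration, not Sentry code: the project, location, and model values are placeholders, and it assumes the google-genai package is installed and Application Default Credentials are configured for Vertex AI.

# Minimal sketch of the new call path; placeholder project/location/model,
# assuming Application Default Credentials are available for Vertex AI.
from google import genai
from google.genai.types import GenerateContentConfig, HttpOptions

client = genai.Client(
    vertexai=True,  # route requests through Vertex AI rather than the Gemini API
    project="my-gcp-project",  # placeholder GCP project
    location="us-central1",  # placeholder region
    http_options=HttpOptions(api_version="v1"),
)

response = client.models.generate_content(
    model="gemini-2.0-flash-001",  # placeholder; any Vertex-hosted model id
    contents="prompt and message text go here",
    config=GenerateContentConfig(
        candidate_count=1,
        max_output_tokens=1024,
        temperature=0.0,
        top_p=1.0,
    ),
)
print(response.text)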
83 changes: 66 additions & 17 deletions tests/sentry/llm/test_vertex.py
@@ -1,30 +1,53 @@
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
-from sentry.llm.usecases import LLMUseCase, complete_prompt
+import pytest
 
+from sentry.llm.exceptions import VertexRequestFailed
+from sentry.llm.usecases import LLMUseCase, complete_prompt, llm_provider_backends
+
 
-def test_complete_prompt(set_sentry_option):
+@pytest.fixture
+def mock_options(set_sentry_option):
     with (
         set_sentry_option(
             "llm.provider.options",
-            {"vertex": {"models": ["vertex-1.0"], "options": {"url": "fake_url"}}},
+            {
+                "vertex": {
+                    "models": ["vertex-1.0"],
+                    "options": {"gcp_project": "my-gcp-project", "gcp_location": "us-central1"},
+                }
+            },
         ),
         set_sentry_option(
             "llm.usecases.options",
             {"example": {"provider": "vertex", "options": {"model": "vertex-1.0"}}},
         ),
         patch(
             "sentry.llm.providers.vertex.VertexProvider._get_access_token",
             return_value="fake_token",
         ),
-        patch(
-            "requests.post",
-            return_value=type(
-                "obj",
-                (object,),
-                {"status_code": 200, "json": lambda x: {"predictions": [{"content": ""}]}},
-            )(),
-        ),
     ):
+        yield
+
+
+class MockGenaiClient:
+    def __init__(self, mock_generate_content):
+        self.models = type(
+            "obj",
+            (object,),
+            {"generate_content": mock_generate_content},
+        )()
+
+
+def test_complete_prompt(mock_options):
+    llm_provider_backends.clear()
+    mock_generate_content = Mock(
+        return_value=type(
+            "obj",
+            (object,),
+            {"status_code": 200, "text": "hello world"},
+        )()
+    )
+
+    with patch(
+        "sentry.llm.providers.vertex.VertexProvider._create_genai_client",
+        return_value=MockGenaiClient(mock_generate_content),
+    ):
         res = complete_prompt(
             usecase=LLMUseCase.EXAMPLE,
@@ -33,4 +56,30 @@ def test_complete_prompt(set_sentry_option):
             temperature=0.0,
             max_output_tokens=1024,
         )
-    assert res == ""
+
+    assert res == "hello world"
+    assert mock_generate_content.call_count == 1
+    assert mock_generate_content.call_args[1]["model"] == "vertex-1.0"
+
+
+def test_complete_prompt_error(mock_options):
+    llm_provider_backends.clear()
+    mock_generate_content = Mock(
+        return_value=type(
+            "obj",
+            (object,),
+            {"status_code": 400},
+        )()
+    )
+
+    with patch(
+        "sentry.llm.providers.vertex.VertexProvider._create_genai_client",
+        return_value=MockGenaiClient(mock_generate_content),
+    ):
+        with pytest.raises(VertexRequestFailed):
+            complete_prompt(
+                usecase=LLMUseCase.EXAMPLE,
+                message="message here",
+                temperature=0.0,
+                max_output_tokens=1024,
+            )
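To exercise just these tests locally, a plain pytest invocation should suffice, assuming a working Sentry development environment:

# run only the Vertex provider tests touched by this PR
pytest tests/sentry/llm/test_vertex.py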