Skip to content

Commit 4672bcd

Browse files
authored
Add vertex AI (#85)
1 parent 70a6024 commit 4672bcd

File tree

16 files changed

+754
-105
lines changed

16 files changed

+754
-105
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ on:
1010

1111
env:
1212
CI: true
13-
RICH_COLUMNS: 120
13+
COLUMNS: 120
1414

1515
permissions:
1616
contents: read
@@ -80,11 +80,12 @@ jobs:
8080
with:
8181
enable-cache: true
8282

83-
- run: uv run --python 3.12 --frozen pytest tests/test_live.py -v --durations=100
83+
- run: uv run --python 3.12 --frozen --extra vertexai pytest tests/test_live.py -v --durations=100
8484
env:
8585
PYDANTIC_AI_LIVE_TEST_DANGEROUS: 'CHARGE-ME!'
8686
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
8787
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
88+
GOOGLE_SERVICE_ACCOUNT_CONTENT: ${{ secrets.GOOGLE_SERVICE_ACCOUNT_CONTENT }}
8889

8990
test:
9091
name: test on ${{ matrix.python-version }}

docs/api/models/vertexai.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# `pydantic_ai.models.vertexai`
2+
3+
::: pydantic_ai.models.vertexai

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ nav:
3838
- api/models/base.md
3939
- api/models/openai.md
4040
- api/models/gemini.md
41+
- api/models/vertexai.md
4142
- api/models/test.md
4243
- api/models/function.md
4344

pydantic_ai/agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,8 @@ async def _get_agent_model(
601601
raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
602602

603603
result_tools = list(self._result_schema.tools.values()) if self._result_schema else None
604-
return model_, custom_model, model_.agent_model(self._retrievers, self._allow_text_result, result_tools)
604+
agent_model = await model_.agent_model(self._retrievers, self._allow_text_result, result_tools)
605+
return model_, custom_model, agent_model
605606

606607
async def _prepare_messages(
607608
self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.Message] | None

pydantic_ai/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,16 @@ class Model(ABC):
4444
"""Abstract class for a model."""
4545

4646
@abstractmethod
47-
def agent_model(
47+
async def agent_model(
4848
self,
4949
retrievers: Mapping[str, AbstractToolDefinition],
5050
allow_text_result: bool,
5151
result_tools: Sequence[AbstractToolDefinition] | None,
5252
) -> AgentModel:
5353
"""Create an agent model.
5454
55+
This is async in case slow/async config checks need to be performed that can't be done in `__init__`.
56+
5557
Args:
5658
retrievers: The retrievers available to the agent.
5759
allow_text_result: Whether a plain text final response/result is permitted.

pydantic_ai/models/function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(self, function: FunctionDef | None = None, *, stream_function: Stre
6565
self.function = function
6666
self.stream_function = stream_function
6767

68-
def agent_model(
68+
async def agent_model(
6969
self,
7070
retrievers: Mapping[str, AbstractToolDefinition],
7171
allow_text_result: bool,

pydantic_ai/models/gemini.py

Lines changed: 71 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from copy import deepcopy
3030
from dataclasses import dataclass, field
3131
from datetime import datetime
32-
from typing import Annotated, Any, Literal, Union
32+
from typing import Annotated, Any, Literal, Protocol, Union
3333

3434
import pydantic_core
3535
from httpx import AsyncClient as AsyncHTTPClient, Response as HTTPResponse
@@ -77,17 +77,17 @@ class GeminiModel(Model):
7777
"""
7878

7979
model_name: GeminiModelName
80-
api_key: str
80+
auth: AuthProtocol
8181
http_client: AsyncHTTPClient
82-
url_template: str
82+
url: str
8383

8484
def __init__(
8585
self,
8686
model_name: GeminiModelName,
8787
*,
8888
api_key: str | None = None,
8989
http_client: AsyncHTTPClient | None = None,
90-
url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:{function}',
90+
url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
9191
):
9292
"""Initialize a Gemini model.
9393
@@ -97,62 +97,94 @@ def __init__(
9797
will be used if available.
9898
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
9999
url_template: The URL template to use for making requests, you shouldn't need to change this,
100-
docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request).
100+
docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request),
101+
`model` is substituted with the model name, and `function` is added to the end of the URL.
101102
"""
102103
self.model_name = model_name
103104
if api_key is None:
104105
if env_api_key := os.getenv('GEMINI_API_KEY'):
105106
api_key = env_api_key
106107
else:
107108
raise exceptions.UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
108-
self.api_key = api_key
109+
self.auth = ApiKeyAuth(api_key)
109110
self.http_client = http_client or cached_async_http_client()
110-
self.url_template = url_template
111+
self.url = url_template.format(model=model_name)
111112

112-
def agent_model(
113+
async def agent_model(
113114
self,
114115
retrievers: Mapping[str, AbstractToolDefinition],
115116
allow_text_result: bool,
116117
result_tools: Sequence[AbstractToolDefinition] | None,
117118
) -> GeminiAgentModel:
118-
check_allow_model_requests()
119-
tools = [_function_from_abstract_tool(t) for t in retrievers.values()]
120-
if result_tools is not None:
121-
tools += [_function_from_abstract_tool(t) for t in result_tools]
122-
123-
if allow_text_result:
124-
tool_config = None
125-
else:
126-
tool_config = _tool_config([t['name'] for t in tools])
127-
128119
return GeminiAgentModel(
129120
http_client=self.http_client,
130121
model_name=self.model_name,
131-
api_key=self.api_key,
132-
tools=_GeminiTools(function_declarations=tools) if tools else None,
133-
tool_config=tool_config,
134-
url_template=self.url_template,
122+
auth=self.auth,
123+
url=self.url,
124+
retrievers=retrievers,
125+
allow_text_result=allow_text_result,
126+
result_tools=result_tools,
135127
)
136128

137129
def name(self) -> str:
138130
return self.model_name
139131

140132

133+
class AuthProtocol(Protocol):
134+
async def headers(self) -> dict[str, str]: ...
135+
136+
141137
@dataclass
138+
class ApiKeyAuth:
139+
api_key: str
140+
141+
async def headers(self) -> dict[str, str]:
142+
# https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
143+
return {'X-Goog-Api-Key': self.api_key}
144+
145+
146+
@dataclass(init=False)
142147
class GeminiAgentModel(AgentModel):
143148
"""Implementation of `AgentModel` for Gemini models."""
144149

145150
http_client: AsyncHTTPClient
146151
model_name: GeminiModelName
147-
api_key: str
152+
auth: AuthProtocol
148153
tools: _GeminiTools | None
149154
tool_config: _GeminiToolConfig | None
150-
url_template: str
155+
url: str
156+
157+
def __init__(
158+
self,
159+
http_client: AsyncHTTPClient,
160+
model_name: GeminiModelName,
161+
auth: AuthProtocol,
162+
url: str,
163+
retrievers: Mapping[str, AbstractToolDefinition],
164+
allow_text_result: bool,
165+
result_tools: Sequence[AbstractToolDefinition] | None,
166+
):
167+
check_allow_model_requests()
168+
tools = [_function_from_abstract_tool(t) for t in retrievers.values()]
169+
if result_tools is not None:
170+
tools += [_function_from_abstract_tool(t) for t in result_tools]
171+
172+
if allow_text_result:
173+
tool_config = None
174+
else:
175+
tool_config = _tool_config([t['name'] for t in tools])
176+
177+
self.http_client = http_client
178+
self.model_name = model_name
179+
self.auth = auth
180+
self.tools = _GeminiTools(function_declarations=tools) if tools else None
181+
self.tool_config = tool_config
182+
self.url = url
151183

152184
async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
153185
async with self._make_request(messages, False) as http_response:
154186
response = _gemini_response_ta.validate_json(await http_response.aread())
155-
return self._process_response(response), _metadata_as_cost(response['usage_metadata'])
187+
return self._process_response(response), _metadata_as_cost(response)
156188

157189
@asynccontextmanager
158190
async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
@@ -178,16 +210,15 @@ async def _make_request(self, messages: list[Message], streamed: bool) -> AsyncI
178210
if self.tool_config is not None:
179211
request_data['tool_config'] = self.tool_config
180212

181-
request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
182-
# https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
213+
url = self.url + ('streamGenerateContent' if streamed else 'generateContent')
214+
183215
headers = {
184-
'X-Goog-Api-Key': self.api_key,
185216
'Content-Type': 'application/json',
186217
'User-Agent': get_user_agent(),
218+
**await self.auth.headers(),
187219
}
188-
url = self.url_template.format(
189-
model=self.model_name, function='streamGenerateContent' if streamed else 'generateContent'
190-
)
220+
221+
request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
191222

192223
async with self.http_client.stream('POST', url, content=request_json, headers=headers) as r:
193224
if r.status_code != 200:
@@ -283,7 +314,7 @@ def get(self, *, final: bool = False) -> Iterable[str]:
283314
new_items, experimental_allow_partial='trailing-strings'
284315
)
285316
for r in new_responses:
286-
self._cost += _metadata_as_cost(r['usage_metadata'])
317+
self._cost += _metadata_as_cost(r)
287318
parts = r['candidates'][0]['content']['parts']
288319
if _all_text_parts(parts):
289320
for part in parts:
@@ -329,7 +360,7 @@ def get(self, *, final: bool = False) -> ModelStructuredResponse:
329360
combined_parts: list[_GeminiFunctionCallPart] = []
330361
self._cost = result.Cost()
331362
for r in responses:
332-
self._cost += _metadata_as_cost(r['usage_metadata'])
363+
self._cost += _metadata_as_cost(r)
333364
candidate = r['candidates'][0]
334365
parts = candidate['content']['parts']
335366
if _all_function_call_parts(parts):
@@ -521,10 +552,12 @@ class _GeminiResponse(TypedDict):
521552
"""Schema for the response from the Gemini API.
522553
523554
See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>
555+
and <https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerateContentResponse>
524556
"""
525557

526558
candidates: list[_GeminiCandidates]
527-
usage_metadata: Annotated[_GeminiUsageMetaData, Field(alias='usageMetadata')]
559+
# usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
560+
usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, Field(alias='usageMetadata')]]
528561
prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, Field(alias='promptFeedback')]]
529562

530563

@@ -582,7 +615,10 @@ class _GeminiUsageMetaData(TypedDict, total=False):
582615
cached_content_token_count: NotRequired[Annotated[int, Field(alias='cachedContentTokenCount')]]
583616

584617

585-
def _metadata_as_cost(metadata: _GeminiUsageMetaData) -> result.Cost:
618+
def _metadata_as_cost(response: _GeminiResponse) -> result.Cost:
619+
metadata = response.get('usage_metadata')
620+
if metadata is None:
621+
return result.Cost()
586622
details: dict[str, int] = {}
587623
if cached_content_token_count := metadata.get('cached_content_token_count'):
588624
details['cached_content_token_count'] = cached_content_token_count

pydantic_ai/models/openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __init__(
8080
else:
8181
self.client = AsyncOpenAI(api_key=api_key, http_client=cached_async_http_client())
8282

83-
def agent_model(
83+
async def agent_model(
8484
self,
8585
retrievers: Mapping[str, AbstractToolDefinition],
8686
allow_text_result: bool,

pydantic_ai/models/test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class TestModel(Model):
7070
agent_model_allow_text_result: bool | None = field(default=None, init=False)
7171
agent_model_result_tools: list[AbstractToolDefinition] | None = field(default=None, init=False)
7272

73-
def agent_model(
73+
async def agent_model(
7474
self,
7575
retrievers: Mapping[str, AbstractToolDefinition],
7676
allow_text_result: bool,

0 commit comments

Comments
 (0)