Skip to content

Commit 66fa21a

Browse files
authored
Include project ID for Google service account auth (#2755)
1 parent c63dbf6 commit 66fa21a

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

docs/models/google.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ credentials = service_account.Credentials.from_service_account_file(
8787
'path/to/service-account.json',
8888
scopes=['https://www.googleapis.com/auth/cloud-platform'],
8989
)
90-
provider = GoogleProvider(credentials=credentials)
90+
provider = GoogleProvider(credentials=credentials, project='your-project-id')
9191
model = GoogleModel('gemini-1.5-flash', provider=provider)
9292
agent = Agent(model)
9393
...

tests/test_live.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
WARNING: running these tests will make use of the relevant API tokens (and cost money).
44
"""
55

6+
import json
67
import os
78
from collections.abc import AsyncIterator, Callable
89
from pathlib import Path
@@ -40,14 +41,15 @@ def vertexai(_: httpx.AsyncClient, tmp_path: Path) -> Model:
4041
from pydantic_ai.providers.google import GoogleProvider
4142

4243
service_account_content = os.environ['GOOGLE_SERVICE_ACCOUNT_CONTENT']
44+
project_id = json.loads(service_account_content)['project_id']
4345
service_account_path = tmp_path / 'service_account.json'
4446
service_account_path.write_text(service_account_content)
4547

4648
credentials = service_account.Credentials.from_service_account_file( # type: ignore[reportUnknownReturnType]
4749
service_account_path,
4850
scopes=['https://www.googleapis.com/auth/cloud-platform'],
4951
)
50-
provider = GoogleProvider(credentials=credentials)
52+
provider = GoogleProvider(credentials=credentials, project=project_id)
5153
return GoogleModel('gemini-1.5-flash', provider=provider)
5254

5355

@@ -91,9 +93,7 @@ def cohere(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
9193
params = [
9294
pytest.param(openai, id='openai'),
9395
pytest.param(gemini, marks=pytest.mark.skip(reason='API seems very flaky'), id='gemini'),
94-
pytest.param(
95-
vertexai, marks=pytest.mark.skip(reason='This needs to be fixed. It raises RuntimeError.'), id='vertexai'
96-
),
96+
pytest.param(vertexai, id='vertexai'),
9797
pytest.param(groq, id='groq'),
9898
pytest.param(anthropic, id='anthropic'),
9999
pytest.param(ollama, id='ollama'),

0 commit comments

Comments
 (0)