Skip to content

Commit 352da43

Browse files
authored
Recreate access token on 401 for Google Vertex provider (#1195)
1 parent cc8ff44 commit 352da43

File tree

2 files changed

+16
-23
lines changed

2 files changed

+16
-23
lines changed

pydantic_ai_slim/pydantic_ai/providers/google_vertex.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import functools
44
from collections.abc import AsyncGenerator, Mapping
5-
from datetime import datetime, timedelta
65
from pathlib import Path
76
from typing import Literal, overload
87

@@ -28,9 +27,6 @@
2827

2928
__all__ = ('GoogleVertexProvider',)
3029

31-
# default expiry is 3600 seconds
32-
MAX_TOKEN_AGE = timedelta(seconds=3000)
33-
3430

3531
class GoogleVertexProvider(Provider[httpx.AsyncClient]):
3632
"""Provider for Vertex AI API."""
@@ -131,19 +127,21 @@ def __init__(
131127
self.region = region
132128

133129
self.credentials = None
134-
self.token_created: datetime | None = None
135130

136131
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
137132
if self.credentials is None:
138133
self.credentials = await self._get_credentials()
139-
if self.credentials.token is None or self._token_expired(): # type: ignore[reportUnknownMemberType]
140-
await anyio.to_thread.run_sync(self._refresh_token)
141-
self.token_created = datetime.now()
134+
if self.credentials.token is None: # type: ignore[reportUnknownMemberType]
135+
await self._refresh_token()
142136
request.headers['Authorization'] = f'Bearer {self.credentials.token}' # type: ignore[reportUnknownMemberType]
143-
144137
# NOTE: This workaround is in place because we might get the project_id from the credentials.
145138
request.url = httpx.URL(str(request.url).replace('projects/None', f'projects/{self.project_id}'))
146-
yield request
139+
response = yield request
140+
141+
if response.status_code == 401:
142+
await self._refresh_token()
143+
request.headers['Authorization'] = f'Bearer {self.credentials.token}' # type: ignore[reportUnknownMemberType]
144+
yield request
147145

148146
async def _get_credentials(self) -> BaseCredentials | ServiceAccountCredentials:
149147
if self.service_account_file is not None:
@@ -166,15 +164,9 @@ async def _get_credentials(self) -> BaseCredentials | ServiceAccountCredentials:
166164
self.project_id = creds_project_id
167165
return creds
168166

169-
def _token_expired(self) -> bool:
170-
if self.token_created is None:
171-
return True
172-
else:
173-
return (datetime.now() - self.token_created) > MAX_TOKEN_AGE
174-
175-
def _refresh_token(self) -> str: # pragma: no cover
167+
async def _refresh_token(self) -> str: # pragma: no cover
176168
assert self.credentials is not None
177-
self.credentials.refresh(Request()) # type: ignore[reportUnknownMemberType]
169+
await anyio.to_thread.run_sync(self.credentials.refresh, Request()) # type: ignore[reportUnknownMemberType]
178170
assert isinstance(self.credentials.token, str), f'Expected token to be a string, got {self.credentials.token}' # type: ignore[reportUnknownMemberType]
179171
return self.credentials.token
180172

tests/providers/test_google_vertex.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ async def test_google_vertex_provider_auth(allow_model_requests: None, http_clie
5757
await provider.client.post('/gemini-1.0-pro:generateContent')
5858
assert provider.region == 'us-central1'
5959
assert getattr(provider.client.auth, 'project_id') == 'my-project-id'
60-
assert getattr(provider.client.auth, 'token_created') is not None
60+
61+
62+
async def mock_refresh_token():
63+
return 'my-token'
6164

6265

6366
async def test_google_vertex_provider_service_account_file(
@@ -67,11 +70,10 @@ async def test_google_vertex_provider_service_account_file(
6770
save_service_account(service_account_path, 'my-project-id')
6871

6972
provider = GoogleVertexProvider(service_account_file=service_account_path)
70-
monkeypatch.setattr(provider.client.auth, '_refresh_token', lambda: 'my-token')
73+
monkeypatch.setattr(provider.client.auth, '_refresh_token', mock_refresh_token)
7174
await provider.client.post('/gemini-1.0-pro:generateContent')
7275
assert provider.region == 'us-central1'
7376
assert getattr(provider.client.auth, 'project_id') == 'my-project-id'
74-
assert getattr(provider.client.auth, 'token_created') is not None
7577

7678

7779
async def test_google_vertex_provider_service_account_file_info(
@@ -80,11 +82,10 @@ async def test_google_vertex_provider_service_account_file_info(
8082
account_info = prepare_service_account_contents('my-project-id')
8183

8284
provider = GoogleVertexProvider(service_account_info=account_info)
83-
monkeypatch.setattr(provider.client.auth, '_refresh_token', lambda: 'my-token')
85+
monkeypatch.setattr(provider.client.auth, '_refresh_token', mock_refresh_token)
8486
await provider.client.post('/gemini-1.0-pro:generateContent')
8587
assert provider.region == 'us-central1'
8688
assert getattr(provider.client.auth, 'project_id') == 'my-project-id'
87-
assert getattr(provider.client.auth, 'token_created') is not None
8889

8990

9091
async def test_google_vertex_provider_service_account_xor(allow_model_requests: None):

0 commit comments

Comments
 (0)