2
2
3
3
import functools
4
4
from collections .abc import AsyncGenerator , Mapping
5
- from datetime import datetime , timedelta
6
5
from pathlib import Path
7
6
from typing import Literal , overload
8
7
28
27
29
28
__all__ = ('GoogleVertexProvider' ,)
30
29
31
- # default expiry is 3600 seconds
32
- MAX_TOKEN_AGE = timedelta (seconds = 3000 )
33
-
34
30
35
31
class GoogleVertexProvider (Provider [httpx .AsyncClient ]):
36
32
"""Provider for Vertex AI API."""
@@ -131,19 +127,21 @@ def __init__(
131
127
self .region = region
132
128
133
129
self .credentials = None
134
- self .token_created : datetime | None = None
135
130
136
131
async def async_auth_flow (self , request : httpx .Request ) -> AsyncGenerator [httpx .Request , httpx .Response ]:
137
132
if self .credentials is None :
138
133
self .credentials = await self ._get_credentials ()
139
- if self .credentials .token is None or self ._token_expired (): # type: ignore[reportUnknownMemberType]
140
- await anyio .to_thread .run_sync (self ._refresh_token )
141
- self .token_created = datetime .now ()
134
+ if self .credentials .token is None : # type: ignore[reportUnknownMemberType]
135
+ await self ._refresh_token ()
142
136
request .headers ['Authorization' ] = f'Bearer { self .credentials .token } ' # type: ignore[reportUnknownMemberType]
143
-
144
137
# NOTE: This workaround is in place because we might get the project_id from the credentials.
145
138
request .url = httpx .URL (str (request .url ).replace ('projects/None' , f'projects/{ self .project_id } ' ))
146
- yield request
139
+ response = yield request
140
+
141
+ if response .status_code == 401 :
142
+ await self ._refresh_token ()
143
+ request .headers ['Authorization' ] = f'Bearer { self .credentials .token } ' # type: ignore[reportUnknownMemberType]
144
+ yield request
147
145
148
146
async def _get_credentials (self ) -> BaseCredentials | ServiceAccountCredentials :
149
147
if self .service_account_file is not None :
@@ -166,15 +164,9 @@ async def _get_credentials(self) -> BaseCredentials | ServiceAccountCredentials:
166
164
self .project_id = creds_project_id
167
165
return creds
168
166
169
- def _token_expired (self ) -> bool :
170
- if self .token_created is None :
171
- return True
172
- else :
173
- return (datetime .now () - self .token_created ) > MAX_TOKEN_AGE
174
-
175
- def _refresh_token (self ) -> str : # pragma: no cover
167
+ async def _refresh_token (self ) -> str : # pragma: no cover
176
168
assert self .credentials is not None
177
- self .credentials .refresh ( Request ()) # type: ignore[reportUnknownMemberType]
169
+ await anyio . to_thread . run_sync ( self .credentials .refresh , Request ()) # type: ignore[reportUnknownMemberType]
178
170
assert isinstance (self .credentials .token , str ), f'Expected token to be a string, got { self .credentials .token } ' # type: ignore[reportUnknownMemberType]
179
171
return self .credentials .token
180
172
0 commit comments