Skip to content

Commit 81cf2e7

Browse files
committed
amp/client: Keep auth token fresh (refresh on every request)
1 parent f90ec8e commit 81cf2e7

File tree

1 file changed

+23
-21
lines changed

1 file changed

+23
-21
lines changed

src/amp/client.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,33 +24,33 @@
2424
class AuthMiddleware(ClientMiddleware):
2525
"""Flight middleware to add Bearer token authentication header."""
2626

27-
def __init__(self, token: str):
27+
def __init__(self, get_token):
2828
"""Initialize auth middleware.
2929
3030
Args:
31-
token: Bearer token to add to requests
31+
get_token: Callable that returns the current access token
3232
"""
33-
self.token = token
33+
self.get_token = get_token
3434

3535
def sending_headers(self):
3636
"""Add Authorization header to outgoing requests."""
37-
return {'authorization': f'Bearer {self.token}'}
37+
return {'authorization': f'Bearer {self.get_token()}'}
3838

3939

4040
class AuthMiddlewareFactory(ClientMiddlewareFactory):
4141
"""Factory for creating auth middleware instances."""
4242

43-
def __init__(self, token: str):
43+
def __init__(self, get_token):
4444
"""Initialize auth middleware factory.
4545
4646
Args:
47-
token: Bearer token to use for authentication
47+
get_token: Callable that returns the current access token
4848
"""
49-
self.token = token
49+
self.get_token = get_token
5050

5151
def start_call(self, info):
5252
"""Create auth middleware for each call."""
53-
return AuthMiddleware(self.token)
53+
return AuthMiddleware(self.get_token)
5454

5555

5656
class QueryBuilder:
@@ -307,26 +307,27 @@ def __init__(
307307
if url and not query_url:
308308
query_url = url
309309

310-
# Resolve auth token with priority: explicit param > env var > auth file
311-
flight_auth_token = None
310+
# Resolve auth token provider with priority: explicit param > env var > auth file
311+
get_token = None
312312
if auth_token:
313-
# Priority 1: Explicit auth_token parameter
314-
flight_auth_token = auth_token
313+
# Priority 1: Explicit auth_token parameter (static token)
314+
get_token = lambda: auth_token
315315
elif os.getenv('AMP_AUTH_TOKEN'):
316-
# Priority 2: AMP_AUTH_TOKEN environment variable
317-
flight_auth_token = os.getenv('AMP_AUTH_TOKEN')
316+
# Priority 2: AMP_AUTH_TOKEN environment variable (static token)
317+
env_token = os.getenv('AMP_AUTH_TOKEN')
318+
get_token = lambda: env_token
318319
elif auth:
319-
# Priority 3: Load from ~/.amp-cli-config/amp_cli_auth
320+
# Priority 3: Load from ~/.amp-cli-config/amp_cli_auth (auto-refreshing)
320321
from amp.auth import AuthService
321322

322323
auth_service = AuthService()
323-
flight_auth_token = auth_service.get_token()
324+
get_token = auth_service.get_token # Callable that auto-refreshes
324325

325326
# Initialize Flight SQL client
326327
if query_url:
327-
# Add auth middleware if token is provided
328-
if flight_auth_token:
329-
middleware = [AuthMiddlewareFactory(flight_auth_token)]
328+
# Add auth middleware if token provider exists
329+
if get_token:
330+
middleware = [AuthMiddlewareFactory(get_token)]
330331
self.conn = flight.connect(query_url, middleware=middleware)
331332
else:
332333
self.conn = flight.connect(query_url)
@@ -342,8 +343,9 @@ def __init__(
342343
if admin_url:
343344
from amp.admin.client import AdminClient
344345

345-
# Pass resolved token to AdminClient (maintains same priority logic)
346-
self._admin_client = AdminClient(admin_url, auth_token=flight_auth_token, auth=False)
346+
# Pass resolved token to AdminClient (get current token if available)
347+
admin_token = get_token() if get_token else None
348+
self._admin_client = AdminClient(admin_url, auth_token=admin_token, auth=False)
347349
else:
348350
self._admin_client = None
349351

0 commit comments

Comments
 (0)