2424class AuthMiddleware (ClientMiddleware ):
2525 """Flight middleware to add Bearer token authentication header."""
2626
27- def __init__ (self , token : str ):
27+ def __init__ (self , get_token ):
2828 """Initialize auth middleware.
2929
3030 Args:
31- token: Bearer token to add to requests
31+ get_token: Callable that returns the current access token
3232 """
33- self .token = token
33+ self .get_token = get_token
3434
3535 def sending_headers (self ):
3636 """Add Authorization header to outgoing requests."""
37- return {'authorization' : f'Bearer { self .token } ' }
37+ return {'authorization' : f'Bearer { self .get_token () } ' }
3838
3939
4040class AuthMiddlewareFactory (ClientMiddlewareFactory ):
4141 """Factory for creating auth middleware instances."""
4242
43- def __init__ (self , token : str ):
43+ def __init__ (self , get_token ):
4444 """Initialize auth middleware factory.
4545
4646 Args:
47- token: Bearer token to use for authentication
47+ get_token: Callable that returns the current access token
4848 """
49- self .token = token
49+ self .get_token = get_token
5050
5151 def start_call (self , info ):
5252 """Create auth middleware for each call."""
53- return AuthMiddleware (self .token )
53+ return AuthMiddleware (self .get_token )
5454
5555
5656class QueryBuilder :
@@ -307,26 +307,27 @@ def __init__(
307307 if url and not query_url :
308308 query_url = url
309309
310- # Resolve auth token with priority: explicit param > env var > auth file
311- flight_auth_token = None
310+ # Resolve auth token provider with priority: explicit param > env var > auth file
311+ get_token = None
312312 if auth_token :
313- # Priority 1: Explicit auth_token parameter
314- flight_auth_token = auth_token
313+ # Priority 1: Explicit auth_token parameter (static token)
314+ get_token = lambda : auth_token
315315 elif os .getenv ('AMP_AUTH_TOKEN' ):
316- # Priority 2: AMP_AUTH_TOKEN environment variable
317- flight_auth_token = os .getenv ('AMP_AUTH_TOKEN' )
316+ # Priority 2: AMP_AUTH_TOKEN environment variable (static token)
317+ env_token = os .getenv ('AMP_AUTH_TOKEN' )
318+ get_token = lambda : env_token
318319 elif auth :
319- # Priority 3: Load from ~/.amp-cli-config/amp_cli_auth
320+ # Priority 3: Load from ~/.amp-cli-config/amp_cli_auth (auto-refreshing)
320321 from amp .auth import AuthService
321322
322323 auth_service = AuthService ()
323- flight_auth_token = auth_service .get_token ()
324+ get_token = auth_service .get_token # Callable that auto-refreshes
324325
325326 # Initialize Flight SQL client
326327 if query_url :
327- # Add auth middleware if token is provided
328- if flight_auth_token :
329- middleware = [AuthMiddlewareFactory (flight_auth_token )]
328+ # Add auth middleware if token provider exists
329+ if get_token :
330+ middleware = [AuthMiddlewareFactory (get_token )]
330331 self .conn = flight .connect (query_url , middleware = middleware )
331332 else :
332333 self .conn = flight .connect (query_url )
@@ -342,8 +343,9 @@ def __init__(
342343 if admin_url :
343344 from amp .admin .client import AdminClient
344345
345- # Pass resolved token to AdminClient (maintains same priority logic)
346- self ._admin_client = AdminClient (admin_url , auth_token = flight_auth_token , auth = False )
346+ # Pass resolved token to AdminClient (get current token if available)
347+ admin_token = get_token () if get_token else None
348+ self ._admin_client = AdminClient (admin_url , auth_token = admin_token , auth = False )
347349 else :
348350 self ._admin_client = None
349351
0 commit comments