2424class AuthMiddleware (ClientMiddleware ):
2525 """Flight middleware to add Bearer token authentication header."""
2626
27- def __init__ (self , token : str ):
27+ def __init__ (self , get_token ):
2828 """Initialize auth middleware.
2929
3030 Args:
31- token: Bearer token to add to requests
31+ get_token: Callable that returns the current access token
3232 """
33- self .token = token
33+ self .get_token = get_token
3434
3535 def sending_headers (self ):
3636 """Add Authorization header to outgoing requests."""
37- return {'authorization' : f'Bearer { self .token } ' }
37+ return {'authorization' : f'Bearer { self .get_token () } ' }
3838
3939
4040class AuthMiddlewareFactory (ClientMiddlewareFactory ):
4141 """Factory for creating auth middleware instances."""
4242
43- def __init__ (self , token : str ):
43+ def __init__ (self , get_token ):
4444 """Initialize auth middleware factory.
4545
4646 Args:
47- token: Bearer token to use for authentication
47+ get_token: Callable that returns the current access token
4848 """
49- self .token = token
49+ self .get_token = get_token
5050
5151 def start_call (self , info ):
5252 """Create auth middleware for each call."""
53- return AuthMiddleware (self .token )
53+ return AuthMiddleware (self .get_token )
5454
5555
5656class QueryBuilder :
@@ -307,26 +307,30 @@ def __init__(
307307 if url and not query_url :
308308 query_url = url
309309
310- # Resolve auth token with priority: explicit param > env var > auth file
311- flight_auth_token = None
310+ # Resolve auth token provider with priority: explicit param > env var > auth file
311+ get_token = None
312312 if auth_token :
313- # Priority 1: Explicit auth_token parameter
314- flight_auth_token = auth_token
313+ # Priority 1: Explicit auth_token parameter (static token)
314+ def get_token ():
315+ return auth_token
315316 elif os .getenv ('AMP_AUTH_TOKEN' ):
316- # Priority 2: AMP_AUTH_TOKEN environment variable
317- flight_auth_token = os .getenv ('AMP_AUTH_TOKEN' )
317+ # Priority 2: AMP_AUTH_TOKEN environment variable (static token)
318+ env_token = os .getenv ('AMP_AUTH_TOKEN' )
319+
320+ def get_token ():
321+ return env_token
318322 elif auth :
319- # Priority 3: Load from ~/.amp-cli-config/amp_cli_auth
323+ # Priority 3: Load from ~/.amp-cli-config/amp_cli_auth (auto-refreshing)
320324 from amp .auth import AuthService
321325
322326 auth_service = AuthService ()
323- flight_auth_token = auth_service .get_token ()
327+ get_token = auth_service .get_token # Callable that auto-refreshes
324328
325329 # Initialize Flight SQL client
326330 if query_url :
327- # Add auth middleware if token is provided
328- if flight_auth_token :
329- middleware = [AuthMiddlewareFactory (flight_auth_token )]
331+ # Add auth middleware if token provider exists
332+ if get_token :
333+ middleware = [AuthMiddlewareFactory (get_token )]
330334 self .conn = flight .connect (query_url , middleware = middleware )
331335 else :
332336 self .conn = flight .connect (query_url )
@@ -342,8 +346,18 @@ def __init__(
342346 if admin_url :
343347 from amp .admin .client import AdminClient
344348
345- # Pass resolved token to AdminClient (maintains same priority logic)
346- self ._admin_client = AdminClient (admin_url , auth_token = flight_auth_token , auth = False )
349+ # Pass auth=True if we have a get_token callable from auth file
350+ # Otherwise pass the static token if available
351+ if auth :
352+ # Use auth file (auto-refreshing)
353+ self ._admin_client = AdminClient (admin_url , auth = True )
354+ elif auth_token or os .getenv ('AMP_AUTH_TOKEN' ):
355+ # Use static token
356+ static_token = auth_token or os .getenv ('AMP_AUTH_TOKEN' )
357+ self ._admin_client = AdminClient (admin_url , auth_token = static_token )
358+ else :
359+ # No auth
360+ self ._admin_client = AdminClient (admin_url )
347361 else :
348362 self ._admin_client = None
349363
0 commit comments