Skip to content

Commit cc9de47

Browse files
committed
client, admin-client: Keep auth token fresh
1 parent f90ec8e commit cc9de47

File tree

5 files changed

+61
-47
lines changed

5 files changed

+61
-47
lines changed

src/amp/admin/client.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,30 +56,25 @@ def __init__(self, base_url: str, auth_token: Optional[str] = None, auth: bool =
5656

5757
self.base_url = base_url.rstrip('/')
5858

59-
# Resolve auth token with priority: explicit param > env var > auth file
60-
resolved_token = None
59+
# Resolve auth token provider with priority: explicit param > env var > auth file
60+
self._get_token = None
6161
if auth_token:
62-
# Priority 1: Explicit auth_token parameter
63-
resolved_token = auth_token
62+
# Priority 1: Explicit auth_token parameter (static token)
63+
self._get_token = lambda: auth_token
6464
elif os.getenv('AMP_AUTH_TOKEN'):
65-
# Priority 2: AMP_AUTH_TOKEN environment variable
66-
resolved_token = os.getenv('AMP_AUTH_TOKEN')
65+
# Priority 2: AMP_AUTH_TOKEN environment variable (static token)
66+
env_token = os.getenv('AMP_AUTH_TOKEN')
67+
self._get_token = lambda: env_token
6768
elif auth:
68-
# Priority 3: Load from ~/.amp-cli-config/amp_cli_auth
69+
# Priority 3: Load from ~/.amp-cli-config/amp_cli_auth (auto-refreshing)
6970
from amp.auth import AuthService
7071

7172
auth_service = AuthService()
72-
resolved_token = auth_service.get_token()
73+
self._get_token = auth_service.get_token # Callable that auto-refreshes
7374

74-
# Build headers
75-
headers = {}
76-
if resolved_token:
77-
headers['Authorization'] = f'Bearer {resolved_token}'
78-
79-
# Create HTTP client
75+
# Create HTTP client (no auth header yet - will be added per-request)
8076
self._http = httpx.Client(
8177
base_url=self.base_url,
82-
headers=headers,
8378
timeout=30.0,
8479
follow_redirects=True,
8580
)
@@ -102,6 +97,12 @@ def _request(
10297
Raises:
10398
AdminAPIError: If the API returns an error response
10499
"""
100+
# Add auth header dynamically (auto-refreshes if needed)
101+
headers = kwargs.get('headers', {})
102+
if self._get_token:
103+
headers['Authorization'] = f'Bearer {self._get_token()}'
104+
kwargs['headers'] = headers
105+
105106
response = self._http.request(method, path, json=json, params=params, **kwargs)
106107

107108
# Handle error responses

src/amp/client.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,33 +24,33 @@
2424
class AuthMiddleware(ClientMiddleware):
2525
"""Flight middleware to add Bearer token authentication header."""
2626

27-
def __init__(self, token: str):
27+
def __init__(self, get_token):
2828
"""Initialize auth middleware.
2929
3030
Args:
31-
token: Bearer token to add to requests
31+
get_token: Callable that returns the current access token
3232
"""
33-
self.token = token
33+
self.get_token = get_token
3434

3535
def sending_headers(self):
3636
"""Add Authorization header to outgoing requests."""
37-
return {'authorization': f'Bearer {self.token}'}
37+
return {'authorization': f'Bearer {self.get_token()}'}
3838

3939

4040
class AuthMiddlewareFactory(ClientMiddlewareFactory):
4141
"""Factory for creating auth middleware instances."""
4242

43-
def __init__(self, token: str):
43+
def __init__(self, get_token):
4444
"""Initialize auth middleware factory.
4545
4646
Args:
47-
token: Bearer token to use for authentication
47+
get_token: Callable that returns the current access token
4848
"""
49-
self.token = token
49+
self.get_token = get_token
5050

5151
def start_call(self, info):
5252
"""Create auth middleware for each call."""
53-
return AuthMiddleware(self.token)
53+
return AuthMiddleware(self.get_token)
5454

5555

5656
class QueryBuilder:
@@ -307,26 +307,30 @@ def __init__(
307307
if url and not query_url:
308308
query_url = url
309309

310-
# Resolve auth token with priority: explicit param > env var > auth file
311-
flight_auth_token = None
310+
# Resolve auth token provider with priority: explicit param > env var > auth file
311+
get_token = None
312312
if auth_token:
313-
# Priority 1: Explicit auth_token parameter
314-
flight_auth_token = auth_token
313+
# Priority 1: Explicit auth_token parameter (static token)
314+
def get_token():
315+
return auth_token
315316
elif os.getenv('AMP_AUTH_TOKEN'):
316-
# Priority 2: AMP_AUTH_TOKEN environment variable
317-
flight_auth_token = os.getenv('AMP_AUTH_TOKEN')
317+
# Priority 2: AMP_AUTH_TOKEN environment variable (static token)
318+
env_token = os.getenv('AMP_AUTH_TOKEN')
319+
320+
def get_token():
321+
return env_token
318322
elif auth:
319-
# Priority 3: Load from ~/.amp-cli-config/amp_cli_auth
323+
# Priority 3: Load from ~/.amp-cli-config/amp_cli_auth (auto-refreshing)
320324
from amp.auth import AuthService
321325

322326
auth_service = AuthService()
323-
flight_auth_token = auth_service.get_token()
327+
get_token = auth_service.get_token # Callable that auto-refreshes
324328

325329
# Initialize Flight SQL client
326330
if query_url:
327-
# Add auth middleware if token is provided
328-
if flight_auth_token:
329-
middleware = [AuthMiddlewareFactory(flight_auth_token)]
331+
# Add auth middleware if token provider exists
332+
if get_token:
333+
middleware = [AuthMiddlewareFactory(get_token)]
330334
self.conn = flight.connect(query_url, middleware=middleware)
331335
else:
332336
self.conn = flight.connect(query_url)
@@ -342,8 +346,18 @@ def __init__(
342346
if admin_url:
343347
from amp.admin.client import AdminClient
344348

345-
# Pass resolved token to AdminClient (maintains same priority logic)
346-
self._admin_client = AdminClient(admin_url, auth_token=flight_auth_token, auth=False)
349+
# Pass auth=True if we have a get_token callable from auth file
350+
# Otherwise pass the static token if available
351+
if auth:
352+
# Use auth file (auto-refreshing)
353+
self._admin_client = AdminClient(admin_url, auth=True)
354+
elif auth_token or os.getenv('AMP_AUTH_TOKEN'):
355+
# Use static token
356+
static_token = auth_token or os.getenv('AMP_AUTH_TOKEN')
357+
self._admin_client = AdminClient(admin_url, auth_token=static_token)
358+
else:
359+
# No auth
360+
self._admin_client = AdminClient(admin_url)
347361
else:
348362
self._admin_client = None
349363

tests/integration/admin/test_admin_client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ def test_admin_client_with_auth_token(self):
2525
"""Test AdminClient with authentication token."""
2626
client = AdminClient('http://localhost:8080', auth_token='test-token')
2727

28-
assert 'Authorization' in client._http.headers
29-
assert client._http.headers['Authorization'] == 'Bearer test-token'
28+
assert client._get_token() == 'test-token'
3029

3130
@respx.mock
3231
def test_request_success(self):

tests/integration/test_snowflake_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def wait_for_snowpipe_data(loader, table_name, expected_count, max_wait=30, poll
6767

6868

6969
# Skip all Snowflake tests
70-
# pytestmark = pytest.mark.skip(reason='Requires active Snowflake account - see module docstring for details')
70+
pytestmark = pytest.mark.skip(reason='Requires active Snowflake account - see module docstring for details')
7171

7272

7373
@pytest.fixture

tests/unit/test_client.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_explicit_token_highest_priority(self, mock_connect, mock_getenv):
104104
call_args = mock_connect.call_args
105105
middleware = call_args[1].get('middleware', [])
106106
assert len(middleware) == 1
107-
assert middleware[0].token == 'explicit-token'
107+
assert middleware[0].get_token() == 'explicit-token'
108108

109109
@patch('amp.client.os.getenv')
110110
@patch('amp.client.flight.connect')
@@ -128,7 +128,7 @@ def getenv_side_effect(key, default=None):
128128
call_args = mock_connect.call_args
129129
middleware = call_args[1].get('middleware', [])
130130
assert len(middleware) == 1
131-
assert middleware[0].token == 'env-var-token'
131+
assert middleware[0].get_token() == 'env-var-token'
132132

133133
@patch('amp.auth.AuthService')
134134
@patch('amp.client.os.getenv')
@@ -150,12 +150,12 @@ def getenv_side_effect(key, default=None):
150150

151151
# Verify auth file was used
152152
mock_auth_service.assert_called_once()
153-
mock_service_instance.get_token.assert_called_once()
154153
mock_connect.assert_called_once()
155154
call_args = mock_connect.call_args
156155
middleware = call_args[1].get('middleware', [])
157156
assert len(middleware) == 1
158-
assert middleware[0].token == 'file-token'
157+
# The middleware should use the auth service's get_token method directly
158+
assert middleware[0].get_token == mock_service_instance.get_token
159159

160160
@patch('amp.client.os.getenv')
161161
@patch('amp.client.flight.connect')
@@ -190,8 +190,8 @@ def test_admin_explicit_token_highest_priority(self, mock_getenv):
190190

191191
client = AdminClient('http://localhost:8080', auth_token='explicit-token')
192192

193-
# Verify explicit token was used
194-
assert client._http.headers.get('Authorization') == 'Bearer explicit-token'
193+
# Verify explicit token was used (check get_token callable)
194+
assert client._get_token() == 'explicit-token'
195195

196196
@patch('amp.admin.client.os.getenv')
197197
def test_admin_env_var_second_priority(self, mock_getenv):
@@ -204,7 +204,7 @@ def test_admin_env_var_second_priority(self, mock_getenv):
204204

205205
# Verify env var was used
206206
mock_getenv.assert_called_with('AMP_AUTH_TOKEN')
207-
assert client._http.headers.get('Authorization') == 'Bearer env-var-token'
207+
assert client._get_token() == 'env-var-token'
208208

209209
@patch('amp.auth.AuthService')
210210
@patch('amp.admin.client.os.getenv')
@@ -220,7 +220,7 @@ def test_admin_auth_file_lowest_priority(self, mock_getenv, mock_auth_service):
220220
client = AdminClient('http://localhost:8080', auth=True)
221221

222222
# Verify auth file was used
223-
assert client._http.headers.get('Authorization') == 'Bearer file-token'
223+
assert client._get_token == mock_service_instance.get_token
224224

225225
@patch('amp.admin.client.os.getenv')
226226
def test_admin_no_auth_when_nothing_provided(self, mock_getenv):

0 commit comments

Comments
 (0)