Skip to content

Commit 6fce5c0

Browse files
committed
feat(auth): Support multiple ways of passing in auth
- `AMP_AUTH_TOKEN` env var, `auth_token` param, or locally stored auth file (from interactive browser login)
1 parent 9e58bb6 commit 6fce5c0

File tree

3 files changed

+198
-16
lines changed

3 files changed

+198
-16
lines changed

src/amp/admin/client.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
with the Amp Admin API over HTTP.
55
"""
66

7+
import os
78
from typing import Optional
89

910
import httpx
@@ -19,15 +20,24 @@ class AdminClient:
1920
2021
Args:
2122
base_url: Base URL for Admin API (e.g., 'http://localhost:8080')
22-
auth_token: Optional Bearer token for authentication
23+
auth_token: Optional Bearer token for authentication (highest priority)
2324
auth: If True, load auth token from ~/.amp-cli-config (shared with TS CLI)
2425
26+
Authentication Priority (highest to lowest):
27+
1. Explicit auth_token parameter
28+
2. AMP_AUTH_TOKEN environment variable
29+
3. auth=True - reads from ~/.amp-cli-config/amp_cli_auth
30+
2531
Example:
26-
>>> # Use amp auth system
32+
>>> # Use amp auth from file
2733
>>> client = AdminClient('http://localhost:8080', auth=True)
2834
>>>
29-
>>> # Or use manual token
35+
>>> # Use manual token
3036
>>> client = AdminClient('http://localhost:8080', auth_token='your-token')
37+
>>>
38+
>>> # Use environment variable
39+
>>> # export AMP_AUTH_TOKEN="eyJhbGci..."
40+
>>> client = AdminClient('http://localhost:8080')
3141
"""
3242

3343
def __init__(self, base_url: str, auth_token: Optional[str] = None, auth: bool = False):
@@ -46,17 +56,25 @@ def __init__(self, base_url: str, auth_token: Optional[str] = None, auth: bool =
4656

4757
self.base_url = base_url.rstrip('/')
4858

49-
# Load token from amp auth system if requested
50-
if auth:
59+
# Resolve auth token with priority: explicit param > env var > auth file
60+
resolved_token = None
61+
if auth_token:
62+
# Priority 1: Explicit auth_token parameter
63+
resolved_token = auth_token
64+
elif os.getenv('AMP_AUTH_TOKEN'):
65+
# Priority 2: AMP_AUTH_TOKEN environment variable
66+
resolved_token = os.getenv('AMP_AUTH_TOKEN')
67+
elif auth:
68+
# Priority 3: Load from ~/.amp-cli-config/amp_cli_auth
5169
from amp.auth import AuthService
5270

5371
auth_service = AuthService()
54-
auth_token = auth_service.get_token()
72+
resolved_token = auth_service.get_token()
5573

5674
# Build headers
5775
headers = {}
58-
if auth_token:
59-
headers['Authorization'] = f'Bearer {auth_token}'
76+
if resolved_token:
77+
headers['Authorization'] = f'Bearer {resolved_token}'
6078

6179
# Create HTTP client
6280
self._http = httpx.Client(

src/amp/client.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import os
23
from typing import Dict, Iterator, List, Optional, Union
34

45
import pyarrow as pa
@@ -270,19 +271,28 @@ class Client:
270271
url: Flight SQL URL (for backward compatibility, treated as query_url)
271272
query_url: Query endpoint URL via Flight SQL (e.g., 'grpc://localhost:1602')
272273
admin_url: Optional Admin API URL (e.g., 'http://localhost:8080')
273-
auth_token: Optional Bearer token for Admin API authentication
274+
auth_token: Optional Bearer token for authentication (highest priority)
274275
auth: If True, load auth token from ~/.amp-cli-config (shared with TS CLI)
275276
277+
Authentication Priority (highest to lowest):
278+
1. Explicit auth_token parameter
279+
2. AMP_AUTH_TOKEN environment variable
280+
3. auth=True - reads from ~/.amp-cli-config/amp_cli_auth
281+
276282
Example:
277283
>>> # Query-only client (backward compatible)
278284
>>> client = Client(url='grpc://localhost:1602')
279285
>>>
280-
>>> # Client with admin capabilities and amp auth
286+
>>> # Client with amp auth from file
281287
>>> client = Client(
282288
... query_url='grpc://localhost:1602',
283289
... admin_url='http://localhost:8080',
284290
... auth=True
285291
... )
292+
>>>
293+
>>> # Client with auth from environment variable
294+
>>> # export AMP_AUTH_TOKEN="eyJhbGci..."
295+
>>> client = Client(query_url='grpc://localhost:1602')
286296
"""
287297

288298
def __init__(
@@ -297,15 +307,20 @@ def __init__(
297307
if url and not query_url:
298308
query_url = url
299309

300-
# Get auth token if using amp auth system
310+
# Resolve auth token with priority: explicit param > env var > auth file
301311
flight_auth_token = None
302-
if auth and not auth_token:
312+
if auth_token:
313+
# Priority 1: Explicit auth_token parameter
314+
flight_auth_token = auth_token
315+
elif os.getenv('AMP_AUTH_TOKEN'):
316+
# Priority 2: AMP_AUTH_TOKEN environment variable
317+
flight_auth_token = os.getenv('AMP_AUTH_TOKEN')
318+
elif auth:
319+
# Priority 3: Load from ~/.amp-cli-config/amp_cli_auth
303320
from amp.auth import AuthService
304321

305322
auth_service = AuthService()
306323
flight_auth_token = auth_service.get_token()
307-
elif auth_token:
308-
flight_auth_token = auth_token
309324

310325
# Initialize Flight SQL client
311326
if query_url:
@@ -327,7 +342,8 @@ def __init__(
327342
if admin_url:
328343
from amp.admin.client import AdminClient
329344

330-
self._admin_client = AdminClient(admin_url, auth_token=auth_token, auth=auth)
345+
# Pass resolved token to AdminClient (maintains same priority logic)
346+
self._admin_client = AdminClient(admin_url, auth_token=flight_auth_token, auth=False)
331347
else:
332348
self._admin_client = None
333349

tests/unit/test_client.py

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import json
99
from pathlib import Path
10-
from unittest.mock import Mock
10+
from unittest.mock import Mock, patch
1111

1212
import pytest
1313

@@ -87,6 +87,154 @@ def test_client_requires_url_or_query_url(self):
8787
Client()
8888

8989

90+
@pytest.mark.unit
91+
class TestClientAuthPriority:
92+
"""Test Client authentication priority (explicit token > env var > auth file)"""
93+
94+
@patch('amp.client.os.getenv')
95+
@patch('amp.client.flight.connect')
96+
def test_explicit_token_highest_priority(self, mock_connect, mock_getenv):
97+
"""Test that explicit auth_token parameter has highest priority"""
98+
mock_getenv.return_value = 'env-var-token'
99+
100+
client = Client(query_url='grpc://localhost:1602', auth_token='explicit-token')
101+
102+
# Verify that explicit token was used (not env var)
103+
mock_connect.assert_called_once()
104+
call_args = mock_connect.call_args
105+
middleware = call_args[1].get('middleware', [])
106+
assert len(middleware) == 1
107+
assert middleware[0].token == 'explicit-token'
108+
109+
@patch('amp.client.os.getenv')
110+
@patch('amp.client.flight.connect')
111+
def test_env_var_second_priority(self, mock_connect, mock_getenv):
112+
"""Test that AMP_AUTH_TOKEN env var has second priority"""
113+
114+
# Return 'env-var-token' for AMP_AUTH_TOKEN, None for others
115+
def getenv_side_effect(key, default=None):
116+
if key == 'AMP_AUTH_TOKEN':
117+
return 'env-var-token'
118+
return default
119+
120+
mock_getenv.side_effect = getenv_side_effect
121+
122+
client = Client(query_url='grpc://localhost:1602')
123+
124+
# Verify env var was checked
125+
calls = [str(call) for call in mock_getenv.call_args_list]
126+
assert any('AMP_AUTH_TOKEN' in call for call in calls)
127+
mock_connect.assert_called_once()
128+
call_args = mock_connect.call_args
129+
middleware = call_args[1].get('middleware', [])
130+
assert len(middleware) == 1
131+
assert middleware[0].token == 'env-var-token'
132+
133+
@patch('amp.auth.AuthService')
134+
@patch('amp.client.os.getenv')
135+
@patch('amp.client.flight.connect')
136+
def test_auth_file_lowest_priority(self, mock_connect, mock_getenv, mock_auth_service):
137+
"""Test that auth=True has lowest priority"""
138+
139+
# Return None for all getenv calls
140+
def getenv_side_effect(key, default=None):
141+
return default
142+
143+
mock_getenv.side_effect = getenv_side_effect
144+
145+
mock_service_instance = Mock()
146+
mock_service_instance.get_token.return_value = 'file-token'
147+
mock_auth_service.return_value = mock_service_instance
148+
149+
client = Client(query_url='grpc://localhost:1602', auth=True)
150+
151+
# Verify auth file was used
152+
mock_auth_service.assert_called_once()
153+
mock_service_instance.get_token.assert_called_once()
154+
mock_connect.assert_called_once()
155+
call_args = mock_connect.call_args
156+
middleware = call_args[1].get('middleware', [])
157+
assert len(middleware) == 1
158+
assert middleware[0].token == 'file-token'
159+
160+
@patch('amp.client.os.getenv')
161+
@patch('amp.client.flight.connect')
162+
def test_no_auth_when_nothing_provided(self, mock_connect, mock_getenv):
163+
"""Test that no auth middleware is added when no auth is provided"""
164+
165+
# Return None/default for all getenv calls
166+
def getenv_side_effect(key, default=None):
167+
return default
168+
169+
mock_getenv.side_effect = getenv_side_effect
170+
171+
client = Client(query_url='grpc://localhost:1602')
172+
173+
# Verify no middleware was added
174+
mock_connect.assert_called_once()
175+
call_args = mock_connect.call_args
176+
middleware = call_args[1].get('middleware')
177+
assert middleware is None or len(middleware) == 0
178+
179+
180+
@pytest.mark.unit
181+
class TestAdminClientAuthPriority:
182+
"""Test AdminClient authentication priority"""
183+
184+
@patch('amp.admin.client.os.getenv')
185+
def test_admin_explicit_token_highest_priority(self, mock_getenv):
186+
"""Test that explicit auth_token parameter has highest priority for AdminClient"""
187+
from amp.admin.client import AdminClient
188+
189+
mock_getenv.return_value = 'env-var-token'
190+
191+
client = AdminClient('http://localhost:8080', auth_token='explicit-token')
192+
193+
# Verify explicit token was used
194+
assert client._http.headers.get('Authorization') == 'Bearer explicit-token'
195+
196+
@patch('amp.admin.client.os.getenv')
197+
def test_admin_env_var_second_priority(self, mock_getenv):
198+
"""Test that AMP_AUTH_TOKEN env var has second priority for AdminClient"""
199+
from amp.admin.client import AdminClient
200+
201+
mock_getenv.return_value = 'env-var-token'
202+
203+
client = AdminClient('http://localhost:8080')
204+
205+
# Verify env var was used
206+
mock_getenv.assert_called_with('AMP_AUTH_TOKEN')
207+
assert client._http.headers.get('Authorization') == 'Bearer env-var-token'
208+
209+
@patch('amp.auth.AuthService')
210+
@patch('amp.admin.client.os.getenv')
211+
def test_admin_auth_file_lowest_priority(self, mock_getenv, mock_auth_service):
212+
"""Test that auth=True has lowest priority for AdminClient"""
213+
from amp.admin.client import AdminClient
214+
215+
mock_getenv.return_value = None
216+
mock_service_instance = Mock()
217+
mock_service_instance.get_token.return_value = 'file-token'
218+
mock_auth_service.return_value = mock_service_instance
219+
220+
client = AdminClient('http://localhost:8080', auth=True)
221+
222+
# Verify auth file was used
223+
assert client._http.headers.get('Authorization') == 'Bearer file-token'
224+
225+
@patch('amp.admin.client.os.getenv')
226+
def test_admin_no_auth_when_nothing_provided(self, mock_getenv):
227+
"""Test that no auth header is added when no auth is provided"""
228+
from amp.admin.client import AdminClient
229+
230+
mock_getenv.return_value = None
231+
232+
client = AdminClient('http://localhost:8080')
233+
234+
# Verify no auth header
235+
assert 'Authorization' not in client._http.headers
236+
237+
90238
@pytest.mark.unit
91239
class TestQueryBuilderManifest:
92240
"""Test QueryBuilder manifest generation"""

0 commit comments

Comments
 (0)