Skip to content

Commit 907740e

Browse files
authored
Add authentication for hosted amp server (#18)
* feat(adminClient): Add auth support for admin requests - Add auth module with models and services - Use new auth for AdminClient requests * feat(client): Add auth support to query client (FlightSQL gRPC) * feat(auth): Add interactive device auth flow via browser * feat(auth): Support multiple ways of passing in auth - `AMP_AUTH_TOKEN` env var, `auth_token` param, or locally stored auth file (from interactive browser login) * fix(auth/service): Update Auth platform url (use thegraph domain) * tests: Add unit tests for auth service * client, admin-client: Keep auth token fresh * auth: Update location
1 parent f538843 commit 907740e

File tree

10 files changed

+1223
-18
lines changed

10 files changed

+1223
-18
lines changed

src/amp/admin/client.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
with the Amp Admin API over HTTP.
55
"""
66

7+
import os
78
from typing import Optional
89

910
import httpx
@@ -19,31 +20,61 @@ class AdminClient:
1920
2021
Args:
2122
base_url: Base URL for Admin API (e.g., 'http://localhost:8080')
22-
auth_token: Optional Bearer token for authentication
23+
auth_token: Optional Bearer token for authentication (highest priority)
24+
auth: If True, load auth token from ~/.amp/cache (shared with TS CLI)
25+
26+
Authentication Priority (highest to lowest):
27+
1. Explicit auth_token parameter
28+
2. AMP_AUTH_TOKEN environment variable
29+
3. auth=True - reads from ~/.amp/cache/amp_cli_auth
2330
2431
Example:
32+
>>> # Use amp auth from file
33+
>>> client = AdminClient('http://localhost:8080', auth=True)
34+
>>>
35+
>>> # Use manual token
36+
>>> client = AdminClient('http://localhost:8080', auth_token='your-token')
37+
>>>
38+
>>> # Use environment variable
39+
>>> # export AMP_AUTH_TOKEN="eyJhbGci..."
2540
>>> client = AdminClient('http://localhost:8080')
26-
>>> datasets = client.datasets.list_all()
2741
"""
2842

29-
def __init__(self, base_url: str, auth_token: Optional[str] = None):
43+
def __init__(self, base_url: str, auth_token: Optional[str] = None, auth: bool = False):
3044
"""Initialize Admin API client.
3145
3246
Args:
3347
base_url: Base URL for Admin API (e.g., 'http://localhost:8080')
3448
auth_token: Optional Bearer token for authentication
49+
auth: If True, load auth token from ~/.amp/cache
50+
51+
Raises:
52+
ValueError: If both auth=True and auth_token are provided
3553
"""
54+
if auth and auth_token:
55+
raise ValueError('Cannot specify both auth=True and auth_token. Choose one authentication method.')
56+
3657
self.base_url = base_url.rstrip('/')
3758

38-
# Build headers
39-
headers = {}
59+
# Resolve auth token provider with priority: explicit param > env var > auth file
60+
self._get_token = None
4061
if auth_token:
41-
headers['Authorization'] = f'Bearer {auth_token}'
42-
43-
# Create HTTP client
62+
# Priority 1: Explicit auth_token parameter (static token)
63+
self._get_token = lambda: auth_token
64+
elif os.getenv('AMP_AUTH_TOKEN'):
65+
# Priority 2: AMP_AUTH_TOKEN environment variable (static token)
66+
env_token = os.getenv('AMP_AUTH_TOKEN')
67+
self._get_token = lambda: env_token
68+
elif auth:
69+
# Priority 3: Load from ~/.amp-cli-config/amp_cli_auth (auto-refreshing)
70+
from amp.auth import AuthService
71+
72+
auth_service = AuthService()
73+
self._get_token = auth_service.get_token # Callable that auto-refreshes
74+
75+
# Create HTTP client (no auth header yet - will be added per-request)
4476
self._http = httpx.Client(
4577
base_url=self.base_url,
46-
headers=headers,
4778
timeout=30.0,
4879
follow_redirects=True,
4980
)
@@ -66,6 +97,12 @@ def _request(
6697
Raises:
6798
AdminAPIError: If the API returns an error response
6899
"""
100+
# Add auth header dynamically (auto-refreshes if needed)
101+
headers = kwargs.get('headers', {})
102+
if self._get_token:
103+
headers['Authorization'] = f'Bearer {self._get_token()}'
104+
kwargs['headers'] = headers
105+
69106
response = self._http.request(method, path, json=json, params=params, **kwargs)
70107

71108
# Handle error responses

src/amp/auth/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Authentication module for amp Python client.
2+
3+
Provides Privy authentication support compatible with the TypeScript CLI.
4+
Reads and manages auth tokens from ~/.amp/cache.
5+
"""
6+
7+
from .device_flow import interactive_device_login
8+
from .models import AuthStorage, RefreshTokenResponse
9+
from .service import AuthService
10+
11+
__all__ = ['AuthService', 'AuthStorage', 'RefreshTokenResponse', 'interactive_device_login']

src/amp/auth/device_flow.py

Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
"""OAuth2 Device Authorization Flow for Privy authentication.
2+
3+
Implements the device authorization grant flow with PKCE for secure authentication.
4+
Matches the TypeScript CLI implementation.
5+
"""
6+
7+
import base64
8+
import hashlib
9+
import secrets
10+
import time
11+
import webbrowser
12+
from typing import Optional, Tuple
13+
14+
import httpx
15+
from pydantic import BaseModel, Field
16+
17+
from .models import AuthStorage
18+
from .service import AUTH_PLATFORM_URL
19+
20+
21+
class DeviceAuthorizationResponse(BaseModel):
22+
"""Response from device authorization endpoint."""
23+
24+
device_code: str = Field(..., description='Device verification code for polling')
25+
user_code: str = Field(..., description='Code for user to enter in browser')
26+
verification_uri: str = Field(..., description='URL where user enters the code')
27+
expires_in: int = Field(..., description='Seconds until device code expires')
28+
interval: int = Field(..., description='Minimum polling interval in seconds')
29+
30+
31+
class DeviceTokenResponse(BaseModel):
32+
"""Response from device token endpoint (success case)."""
33+
34+
access_token: str = Field(..., description='Access token for authenticated requests')
35+
refresh_token: str = Field(..., description='Refresh token for renewing access')
36+
user_id: str = Field(..., description='Authenticated user ID')
37+
user_accounts: list[str] = Field(..., description='List of user accounts/wallets')
38+
expires_in: int = Field(..., description='Seconds until token expires')
39+
40+
41+
class DeviceTokenPendingResponse(BaseModel):
42+
"""Response when authorization is still pending."""
43+
44+
error: str = Field('authorization_pending', description='Error code')
45+
46+
47+
class DeviceTokenExpiredResponse(BaseModel):
48+
"""Response when device code has expired."""
49+
50+
error: str = Field('expired_token', description='Error code')
51+
52+
53+
def generate_pkce_pair() -> Tuple[str, str]:
54+
"""Generate PKCE code_verifier and code_challenge.
55+
56+
Returns:
57+
Tuple of (code_verifier, code_challenge)
58+
"""
59+
# Generate cryptographically random code_verifier
60+
# Must be 43-128 characters using unreserved chars [A-Za-z0-9-._~]
61+
code_verifier_bytes = secrets.token_bytes(32)
62+
code_verifier = base64.urlsafe_b64encode(code_verifier_bytes).decode('utf-8').rstrip('=')
63+
64+
# Generate code_challenge = BASE64URL(SHA256(code_verifier))
65+
challenge_bytes = hashlib.sha256(code_verifier.encode('utf-8')).digest()
66+
code_challenge = base64.urlsafe_b64encode(challenge_bytes).decode('utf-8').rstrip('=')
67+
68+
return code_verifier, code_challenge
69+
70+
71+
def request_device_authorization(http_client: httpx.Client) -> Tuple[DeviceAuthorizationResponse, str]:
72+
"""Request device authorization from auth platform.
73+
74+
Args:
75+
http_client: HTTP client to use for request
76+
77+
Returns:
78+
Tuple of (DeviceAuthorizationResponse, code_verifier)
79+
80+
Raises:
81+
httpx.HTTPStatusError: If request fails
82+
ValueError: If response is invalid
83+
"""
84+
# Generate PKCE parameters
85+
code_verifier, code_challenge = generate_pkce_pair()
86+
87+
# Request device authorization
88+
url = f'{AUTH_PLATFORM_URL}api/v1/device/authorize'
89+
response = http_client.post(
90+
url, json={'code_challenge': code_challenge, 'code_challenge_method': 'S256'}, timeout=30.0
91+
)
92+
93+
if response.status_code != 200:
94+
raise ValueError(f'Device authorization failed: {response.status_code} - {response.text}')
95+
96+
device_auth = DeviceAuthorizationResponse.model_validate(response.json())
97+
return device_auth, code_verifier
98+
99+
100+
def poll_for_token(http_client: httpx.Client, device_code: str, code_verifier: str) -> Optional[DeviceTokenResponse]:
101+
"""Poll device token endpoint once.
102+
103+
Args:
104+
http_client: HTTP client to use for request
105+
device_code: Device code from authorization response
106+
code_verifier: PKCE code verifier
107+
108+
Returns:
109+
DeviceTokenResponse if auth complete, None if still pending
110+
111+
Raises:
112+
ValueError: If device code expired or other error
113+
"""
114+
url = f'{AUTH_PLATFORM_URL}api/v1/device/token'
115+
response = http_client.get(url, params={'device_code': device_code, 'code_verifier': code_verifier}, timeout=10.0)
116+
117+
data = response.json()
118+
119+
# Check for error responses
120+
if 'error' in data:
121+
error = data['error']
122+
if error == 'authorization_pending':
123+
return None # Still pending
124+
elif error == 'expired_token':
125+
raise ValueError('Device code expired. Please try again.')
126+
else:
127+
raise ValueError(f'Token polling error: {error}')
128+
129+
# Success - parse token response
130+
return DeviceTokenResponse.model_validate(data)
131+
132+
133+
def poll_until_authenticated(
134+
http_client: httpx.Client,
135+
device_code: str,
136+
code_verifier: str,
137+
interval: int,
138+
expires_in: int,
139+
on_poll: Optional[callable] = None,
140+
verbose: bool = False,
141+
) -> DeviceTokenResponse:
142+
"""Poll for token until authenticated or timeout.
143+
144+
Args:
145+
http_client: HTTP client to use for requests
146+
device_code: Device code from authorization
147+
code_verifier: PKCE code verifier
148+
interval: Minimum polling interval in seconds
149+
expires_in: Seconds until device code expires
150+
on_poll: Optional callback called on each poll attempt
151+
152+
Returns:
153+
DeviceTokenResponse when authentication completes
154+
155+
Raises:
156+
ValueError: If authentication times out or fails
157+
"""
158+
start_time = time.time()
159+
poll_count = 0
160+
max_polls = int(expires_in / interval) + 5 # Add some buffer
161+
162+
while poll_count < max_polls:
163+
elapsed = time.time() - start_time
164+
if elapsed > expires_in:
165+
raise ValueError('Authentication timed out. Please try again.')
166+
167+
if on_poll:
168+
on_poll(poll_count, elapsed)
169+
170+
# Poll for token
171+
try:
172+
token_response = poll_for_token(http_client, device_code, code_verifier)
173+
if token_response:
174+
return token_response
175+
except ValueError as e:
176+
if 'expired' in str(e).lower():
177+
raise
178+
# Other errors, log and continue polling
179+
if verbose:
180+
print(f'\n⚠ Polling error (will retry): {e}')
181+
pass
182+
except Exception as e:
183+
# Log unexpected errors
184+
if verbose:
185+
print(f'\n⚠ Unexpected error (will retry): {type(e).__name__}: {e}')
186+
pass
187+
188+
# Wait before next poll
189+
time.sleep(interval)
190+
poll_count += 1
191+
192+
raise ValueError('Authentication timed out. Please try again.')
193+
194+
195+
def open_browser(url: str) -> bool:
196+
"""Open URL in browser.
197+
198+
Args:
199+
url: URL to open
200+
201+
Returns:
202+
True if browser opened successfully
203+
"""
204+
try:
205+
webbrowser.open(url)
206+
return True
207+
except Exception:
208+
return False
209+
210+
211+
def interactive_device_login(verbose: bool = True, auto_open_browser: bool = True) -> AuthStorage:
212+
"""Perform interactive device authorization flow.
213+
214+
Args:
215+
verbose: Print progress messages
216+
auto_open_browser: Automatically open browser for user
217+
218+
Returns:
219+
AuthStorage with tokens
220+
221+
Raises:
222+
ValueError: If authentication fails
223+
"""
224+
http_client = httpx.Client()
225+
226+
try:
227+
# Step 1: Request device authorization
228+
if verbose:
229+
print('🔐 Starting authentication...\n')
230+
231+
device_auth, code_verifier = request_device_authorization(http_client)
232+
233+
# Step 2: Display user code and open browser
234+
if verbose:
235+
print(f'📱 Verification Code: {device_auth.user_code}')
236+
print(f'🌐 Verification URL: {device_auth.verification_uri}\n')
237+
238+
if auto_open_browser:
239+
if verbose:
240+
print('Opening browser...')
241+
if open_browser(device_auth.verification_uri):
242+
if verbose:
243+
print('✓ Browser opened')
244+
else:
245+
if verbose:
246+
print('✗ Could not open browser automatically')
247+
print(f' Please open: {device_auth.verification_uri}')
248+
249+
if verbose:
250+
print(f'\n⏳ Waiting for authentication (expires in {device_auth.expires_in}s)...')
251+
print(' Complete the authentication in your browser.\n')
252+
253+
# Step 3: Poll for token
254+
spinner_frames = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']
255+
256+
def poll_callback(count: int, elapsed: float):
257+
if verbose:
258+
spinner = spinner_frames[count % len(spinner_frames)]
259+
print(f'\r{spinner} Polling... ({int(elapsed)}s elapsed)', end='', flush=True)
260+
261+
token_response = poll_until_authenticated(
262+
http_client,
263+
device_auth.device_code,
264+
code_verifier,
265+
device_auth.interval,
266+
device_auth.expires_in,
267+
poll_callback,
268+
verbose=verbose,
269+
)
270+
271+
if verbose:
272+
print('\r✓ Authentication successful! \n')
273+
274+
# Step 4: Create auth storage
275+
now_ms = int(time.time() * 1000)
276+
expiry_ms = now_ms + (token_response.expires_in * 1000)
277+
278+
auth_storage = AuthStorage(
279+
accessToken=token_response.access_token,
280+
refreshToken=token_response.refresh_token,
281+
userId=token_response.user_id,
282+
accounts=token_response.user_accounts,
283+
expiry=expiry_ms,
284+
)
285+
286+
return auth_storage
287+
288+
finally:
289+
http_client.close()

0 commit comments

Comments
 (0)