Skip to content

Commit 9e58bb6

Browse files
committed
feat(auth): Add interactive device auth flow via browser
1 parent cd64269 commit 9e58bb6

File tree

3 files changed

+321
-4
lines changed

3 files changed

+321
-4
lines changed

src/amp/auth/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
Reads and manages auth tokens from ~/.amp-cli-config.
55
"""
66

7+
from .device_flow import interactive_device_login
78
from .models import AuthStorage, RefreshTokenResponse
89
from .service import AuthService
910

10-
__all__ = ['AuthService', 'AuthStorage', 'RefreshTokenResponse']
11+
__all__ = ['AuthService', 'AuthStorage', 'RefreshTokenResponse', 'interactive_device_login']

src/amp/auth/device_flow.py

Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
"""OAuth2 Device Authorization Flow for Privy authentication.
2+
3+
Implements the device authorization grant flow with PKCE for secure authentication.
4+
Matches the TypeScript CLI implementation.
5+
"""
6+
7+
import base64
8+
import hashlib
9+
import secrets
10+
import time
11+
import webbrowser
12+
from typing import Optional, Tuple
13+
14+
import httpx
15+
from pydantic import BaseModel, Field
16+
17+
from .models import AuthStorage
18+
from .service import AUTH_PLATFORM_URL
19+
20+
21+
class DeviceAuthorizationResponse(BaseModel):
22+
"""Response from device authorization endpoint."""
23+
24+
device_code: str = Field(..., description='Device verification code for polling')
25+
user_code: str = Field(..., description='Code for user to enter in browser')
26+
verification_uri: str = Field(..., description='URL where user enters the code')
27+
expires_in: int = Field(..., description='Seconds until device code expires')
28+
interval: int = Field(..., description='Minimum polling interval in seconds')
29+
30+
31+
class DeviceTokenResponse(BaseModel):
32+
"""Response from device token endpoint (success case)."""
33+
34+
access_token: str = Field(..., description='Access token for authenticated requests')
35+
refresh_token: str = Field(..., description='Refresh token for renewing access')
36+
user_id: str = Field(..., description='Authenticated user ID')
37+
user_accounts: list[str] = Field(..., description='List of user accounts/wallets')
38+
expires_in: int = Field(..., description='Seconds until token expires')
39+
40+
41+
class DeviceTokenPendingResponse(BaseModel):
42+
"""Response when authorization is still pending."""
43+
44+
error: str = Field('authorization_pending', description='Error code')
45+
46+
47+
class DeviceTokenExpiredResponse(BaseModel):
48+
"""Response when device code has expired."""
49+
50+
error: str = Field('expired_token', description='Error code')
51+
52+
53+
def generate_pkce_pair() -> Tuple[str, str]:
54+
"""Generate PKCE code_verifier and code_challenge.
55+
56+
Returns:
57+
Tuple of (code_verifier, code_challenge)
58+
"""
59+
# Generate cryptographically random code_verifier
60+
# Must be 43-128 characters using unreserved chars [A-Za-z0-9-._~]
61+
code_verifier_bytes = secrets.token_bytes(32)
62+
code_verifier = base64.urlsafe_b64encode(code_verifier_bytes).decode('utf-8').rstrip('=')
63+
64+
# Generate code_challenge = BASE64URL(SHA256(code_verifier))
65+
challenge_bytes = hashlib.sha256(code_verifier.encode('utf-8')).digest()
66+
code_challenge = base64.urlsafe_b64encode(challenge_bytes).decode('utf-8').rstrip('=')
67+
68+
return code_verifier, code_challenge
69+
70+
71+
def request_device_authorization(http_client: httpx.Client) -> Tuple[DeviceAuthorizationResponse, str]:
72+
"""Request device authorization from auth platform.
73+
74+
Args:
75+
http_client: HTTP client to use for request
76+
77+
Returns:
78+
Tuple of (DeviceAuthorizationResponse, code_verifier)
79+
80+
Raises:
81+
httpx.HTTPStatusError: If request fails
82+
ValueError: If response is invalid
83+
"""
84+
# Generate PKCE parameters
85+
code_verifier, code_challenge = generate_pkce_pair()
86+
87+
# Request device authorization
88+
url = f'{AUTH_PLATFORM_URL}api/v1/device/authorize'
89+
response = http_client.post(
90+
url, json={'code_challenge': code_challenge, 'code_challenge_method': 'S256'}, timeout=30.0
91+
)
92+
93+
if response.status_code != 200:
94+
raise ValueError(f'Device authorization failed: {response.status_code} - {response.text}')
95+
96+
device_auth = DeviceAuthorizationResponse.model_validate(response.json())
97+
return device_auth, code_verifier
98+
99+
100+
def poll_for_token(http_client: httpx.Client, device_code: str, code_verifier: str) -> Optional[DeviceTokenResponse]:
101+
"""Poll device token endpoint once.
102+
103+
Args:
104+
http_client: HTTP client to use for request
105+
device_code: Device code from authorization response
106+
code_verifier: PKCE code verifier
107+
108+
Returns:
109+
DeviceTokenResponse if auth complete, None if still pending
110+
111+
Raises:
112+
ValueError: If device code expired or other error
113+
"""
114+
url = f'{AUTH_PLATFORM_URL}api/v1/device/token'
115+
response = http_client.get(url, params={'device_code': device_code, 'code_verifier': code_verifier}, timeout=10.0)
116+
117+
data = response.json()
118+
119+
# Check for error responses
120+
if 'error' in data:
121+
error = data['error']
122+
if error == 'authorization_pending':
123+
return None # Still pending
124+
elif error == 'expired_token':
125+
raise ValueError('Device code expired. Please try again.')
126+
else:
127+
raise ValueError(f'Token polling error: {error}')
128+
129+
# Success - parse token response
130+
return DeviceTokenResponse.model_validate(data)
131+
132+
133+
def poll_until_authenticated(
134+
http_client: httpx.Client,
135+
device_code: str,
136+
code_verifier: str,
137+
interval: int,
138+
expires_in: int,
139+
on_poll: Optional[callable] = None,
140+
verbose: bool = False,
141+
) -> DeviceTokenResponse:
142+
"""Poll for token until authenticated or timeout.
143+
144+
Args:
145+
http_client: HTTP client to use for requests
146+
device_code: Device code from authorization
147+
code_verifier: PKCE code verifier
148+
interval: Minimum polling interval in seconds
149+
expires_in: Seconds until device code expires
150+
on_poll: Optional callback called on each poll attempt
151+
152+
Returns:
153+
DeviceTokenResponse when authentication completes
154+
155+
Raises:
156+
ValueError: If authentication times out or fails
157+
"""
158+
start_time = time.time()
159+
poll_count = 0
160+
max_polls = int(expires_in / interval) + 5 # Add some buffer
161+
162+
while poll_count < max_polls:
163+
elapsed = time.time() - start_time
164+
if elapsed > expires_in:
165+
raise ValueError('Authentication timed out. Please try again.')
166+
167+
if on_poll:
168+
on_poll(poll_count, elapsed)
169+
170+
# Poll for token
171+
try:
172+
token_response = poll_for_token(http_client, device_code, code_verifier)
173+
if token_response:
174+
return token_response
175+
except ValueError as e:
176+
if 'expired' in str(e).lower():
177+
raise
178+
# Other errors, log and continue polling
179+
if verbose:
180+
print(f'\n⚠ Polling error (will retry): {e}')
181+
pass
182+
except Exception as e:
183+
# Log unexpected errors
184+
if verbose:
185+
print(f'\n⚠ Unexpected error (will retry): {type(e).__name__}: {e}')
186+
pass
187+
188+
# Wait before next poll
189+
time.sleep(interval)
190+
poll_count += 1
191+
192+
raise ValueError('Authentication timed out. Please try again.')
193+
194+
195+
def open_browser(url: str) -> bool:
196+
"""Open URL in browser.
197+
198+
Args:
199+
url: URL to open
200+
201+
Returns:
202+
True if browser opened successfully
203+
"""
204+
try:
205+
webbrowser.open(url)
206+
return True
207+
except Exception:
208+
return False
209+
210+
211+
def interactive_device_login(verbose: bool = True, auto_open_browser: bool = True) -> AuthStorage:
212+
"""Perform interactive device authorization flow.
213+
214+
Args:
215+
verbose: Print progress messages
216+
auto_open_browser: Automatically open browser for user
217+
218+
Returns:
219+
AuthStorage with tokens
220+
221+
Raises:
222+
ValueError: If authentication fails
223+
"""
224+
http_client = httpx.Client()
225+
226+
try:
227+
# Step 1: Request device authorization
228+
if verbose:
229+
print('🔐 Starting authentication...\n')
230+
231+
device_auth, code_verifier = request_device_authorization(http_client)
232+
233+
# Step 2: Display user code and open browser
234+
if verbose:
235+
print(f'📱 Verification Code: {device_auth.user_code}')
236+
print(f'🌐 Verification URL: {device_auth.verification_uri}\n')
237+
238+
if auto_open_browser:
239+
if verbose:
240+
print('Opening browser...')
241+
if open_browser(device_auth.verification_uri):
242+
if verbose:
243+
print('✓ Browser opened')
244+
else:
245+
if verbose:
246+
print('✗ Could not open browser automatically')
247+
print(f' Please open: {device_auth.verification_uri}')
248+
249+
if verbose:
250+
print(f'\n⏳ Waiting for authentication (expires in {device_auth.expires_in}s)...')
251+
print(' Complete the authentication in your browser.\n')
252+
253+
# Step 3: Poll for token
254+
spinner_frames = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']
255+
256+
def poll_callback(count: int, elapsed: float):
257+
if verbose:
258+
spinner = spinner_frames[count % len(spinner_frames)]
259+
print(f'\r{spinner} Polling... ({int(elapsed)}s elapsed)', end='', flush=True)
260+
261+
token_response = poll_until_authenticated(
262+
http_client,
263+
device_auth.device_code,
264+
code_verifier,
265+
device_auth.interval,
266+
device_auth.expires_in,
267+
poll_callback,
268+
verbose=verbose,
269+
)
270+
271+
if verbose:
272+
print('\r✓ Authentication successful! \n')
273+
274+
# Step 4: Create auth storage
275+
now_ms = int(time.time() * 1000)
276+
expiry_ms = now_ms + (token_response.expires_in * 1000)
277+
278+
auth_storage = AuthStorage(
279+
accessToken=token_response.access_token,
280+
refreshToken=token_response.refresh_token,
281+
userId=token_response.user_id,
282+
accounts=token_response.user_accounts,
283+
expiry=expiry_ms,
284+
)
285+
286+
return auth_storage
287+
288+
finally:
289+
http_client.close()

src/amp/auth/service.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,7 @@ def refresh_token(self, auth: AuthStorage) -> AuthStorage:
196196

197197
# Validate user ID matches (security check)
198198
if refresh_response.user.id != auth.userId:
199-
raise ValueError(
200-
f'User ID mismatch after refresh. Expected {auth.userId}, got {refresh_response.user.id}'
201-
)
199+
raise ValueError(f'User ID mismatch after refresh. Expected {auth.userId}, got {refresh_response.user.id}')
202200

203201
# Calculate new expiry
204202
now_ms = int(time.time() * 1000)
@@ -225,3 +223,32 @@ def __enter__(self):
225223
def __exit__(self, exc_type, exc_val, exc_tb):
226224
"""Context manager exit."""
227225
self._http.close()
226+
227+
def login(self, verbose: bool = True, auto_open_browser: bool = True) -> None:
228+
"""Perform interactive browser-based login.
229+
230+
Opens browser for OAuth2 device authorization flow with PKCE.
231+
Saves authentication tokens to ~/.amp-cli-config/amp_cli_auth.
232+
233+
Args:
234+
verbose: Print progress messages
235+
auto_open_browser: Automatically open browser
236+
237+
Raises:
238+
ValueError: If authentication fails
239+
240+
Example:
241+
>>> auth = AuthService()
242+
>>> auth.login() # Opens browser for authentication
243+
>>> # Auth tokens saved to ~/.amp-cli-config/amp_cli_auth
244+
"""
245+
from .device_flow import interactive_device_login
246+
247+
# Perform device authorization flow
248+
auth_storage = interactive_device_login(verbose=verbose, auto_open_browser=auto_open_browser)
249+
250+
# Save to config file
251+
self.save_auth(auth_storage)
252+
253+
if verbose:
254+
print(f'✓ Authentication saved to {self.config_path}')

0 commit comments

Comments
 (0)