|
| 1 | +""" |
| 2 | +Auth0 Device Authorization Flow implementation. |
| 3 | +
|
| 4 | +Follows RFC 8628: https://tools.ietf.org/html/rfc8628 |
| 5 | +""" |
| 6 | + |
| 7 | +import asyncio |
| 8 | +import time |
| 9 | +from dataclasses import dataclass |
| 10 | + |
| 11 | +import httpx |
| 12 | + |
| 13 | +from .exceptions import AuthenticationError, AuthenticationTimeout |
| 14 | + |
| 15 | + |
| 16 | +@dataclass |
| 17 | +class DeviceCodeResponse: |
| 18 | + """Response from /oauth/device/code endpoint.""" |
| 19 | + |
| 20 | + device_code: str |
| 21 | + user_code: str |
| 22 | + verification_uri: str |
| 23 | + verification_uri_complete: str |
| 24 | + expires_in: int |
| 25 | + interval: int |
| 26 | + |
| 27 | + |
| 28 | +@dataclass |
| 29 | +class TokenResponse: |
| 30 | + """Response from /oauth/token endpoint.""" |
| 31 | + |
| 32 | + access_token: str |
| 33 | + refresh_token: str | None |
| 34 | + id_token: str | None |
| 35 | + token_type: str |
| 36 | + expires_in: int |
| 37 | + scope: str |
| 38 | + |
| 39 | + |
| 40 | +class DeviceAuthFlow: |
| 41 | + """Implements Auth0 Device Authorization Flow.""" |
| 42 | + |
| 43 | + def __init__( |
| 44 | + self, |
| 45 | + domain: str, |
| 46 | + client_id: str, |
| 47 | + audience: str, |
| 48 | + scopes: str = "openid profile email offline_access", |
| 49 | + ): |
| 50 | + self.domain = domain |
| 51 | + self.client_id = client_id |
| 52 | + self.audience = audience |
| 53 | + self.scopes = scopes |
| 54 | + self.device_code_url = f"https://{domain}/oauth/device/code" |
| 55 | + self.token_url = f"https://{domain}/oauth/token" |
| 56 | + |
| 57 | + async def initiate(self) -> DeviceCodeResponse: |
| 58 | + """ |
| 59 | + Step 1: Initiate device authorization flow. |
| 60 | +
|
| 61 | + Returns device code and user instructions. |
| 62 | + """ |
| 63 | + async with httpx.AsyncClient() as client: |
| 64 | + response = await client.post( |
| 65 | + self.device_code_url, |
| 66 | + data={ # Use form data instead of JSON for better compatibility |
| 67 | + "client_id": self.client_id, |
| 68 | + "scope": self.scopes, |
| 69 | + "audience": self.audience, |
| 70 | + }, |
| 71 | + ) |
| 72 | + response.raise_for_status() |
| 73 | + data = response.json() |
| 74 | + |
| 75 | + return DeviceCodeResponse( |
| 76 | + device_code=data["device_code"], |
| 77 | + user_code=data["user_code"], |
| 78 | + verification_uri=data["verification_uri"], |
| 79 | + verification_uri_complete=data.get( |
| 80 | + "verification_uri_complete", |
| 81 | + f"{data['verification_uri']}?user_code={data['user_code']}", |
| 82 | + ), |
| 83 | + expires_in=data["expires_in"], |
| 84 | + interval=data["interval"], |
| 85 | + ) |
| 86 | + |
| 87 | + async def poll_for_token( |
| 88 | + self, device_code: str, interval: int, timeout: int = 900 |
| 89 | + ) -> TokenResponse: |
| 90 | + """ |
| 91 | + Step 2: Poll for access token. |
| 92 | +
|
| 93 | + Args: |
| 94 | + device_code: Device code from initiate() |
| 95 | + interval: Polling interval in seconds |
| 96 | + timeout: Max time to wait in seconds |
| 97 | +
|
| 98 | + Returns: |
| 99 | + Token response with access_token and refresh_token |
| 100 | +
|
| 101 | + Raises: |
| 102 | + AuthenticationTimeout: If user doesn't complete auth in time |
| 103 | + AuthenticationError: If auth fails or user denies |
| 104 | + """ |
| 105 | + start_time = time.time() |
| 106 | + current_interval = interval |
| 107 | + |
| 108 | + async with httpx.AsyncClient() as client: |
| 109 | + while time.time() - start_time < timeout: |
| 110 | + await asyncio.sleep(current_interval) |
| 111 | + |
| 112 | + try: |
| 113 | + response = await client.post( |
| 114 | + self.token_url, |
| 115 | + data={ # Use form data for better OAuth compatibility |
| 116 | + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", |
| 117 | + "device_code": device_code, |
| 118 | + "client_id": self.client_id, |
| 119 | + }, |
| 120 | + ) |
| 121 | + |
| 122 | + # Success! |
| 123 | + if response.status_code == 200: |
| 124 | + data = response.json() |
| 125 | + return TokenResponse( |
| 126 | + access_token=data["access_token"], |
| 127 | + refresh_token=data.get("refresh_token"), |
| 128 | + id_token=data.get("id_token"), |
| 129 | + token_type=data["token_type"], |
| 130 | + expires_in=data["expires_in"], |
| 131 | + scope=data.get("scope", ""), |
| 132 | + ) |
| 133 | + |
| 134 | + # Handle errors |
| 135 | + error_data = response.json() |
| 136 | + error = error_data.get("error") |
| 137 | + |
| 138 | + if error == "authorization_pending": |
| 139 | + # User hasn't completed auth yet |
| 140 | + continue |
| 141 | + elif error == "slow_down": |
| 142 | + # Increase polling interval |
| 143 | + current_interval += 5 |
| 144 | + continue |
| 145 | + elif error == "expired_token": |
| 146 | + raise AuthenticationTimeout("Device code expired") |
| 147 | + elif error == "access_denied": |
| 148 | + raise AuthenticationError("User denied authorization") |
| 149 | + else: |
| 150 | + raise AuthenticationError(f"Authentication failed: {error}") |
| 151 | + |
| 152 | + except httpx.HTTPError as e: |
| 153 | + raise AuthenticationError(f"Network error: {e}") |
| 154 | + |
| 155 | + raise AuthenticationTimeout("Authentication timeout") |
| 156 | + |
| 157 | + async def refresh_token(self, refresh_token: str) -> TokenResponse: |
| 158 | + """ |
| 159 | + Refresh an expired access token. |
| 160 | +
|
| 161 | + Args: |
| 162 | + refresh_token: Refresh token from previous authentication |
| 163 | +
|
| 164 | + Returns: |
| 165 | + New token response |
| 166 | + """ |
| 167 | + async with httpx.AsyncClient() as client: |
| 168 | + response = await client.post( |
| 169 | + self.token_url, |
| 170 | + data={ # Use form data for better OAuth compatibility |
| 171 | + "grant_type": "refresh_token", |
| 172 | + "client_id": self.client_id, |
| 173 | + "refresh_token": refresh_token, |
| 174 | + }, |
| 175 | + ) |
| 176 | + |
| 177 | + if response.status_code != 200: |
| 178 | + error_data = response.json() |
| 179 | + raise AuthenticationError( |
| 180 | + f"Token refresh failed: {error_data.get('error')}" |
| 181 | + ) |
| 182 | + |
| 183 | + data = response.json() |
| 184 | + return TokenResponse( |
| 185 | + access_token=data["access_token"], |
| 186 | + refresh_token=data.get( |
| 187 | + "refresh_token", refresh_token |
| 188 | + ), # May return same |
| 189 | + id_token=data.get("id_token"), |
| 190 | + token_type=data["token_type"], |
| 191 | + expires_in=data["expires_in"], |
| 192 | + scope=data.get("scope", ""), |
| 193 | + ) |
0 commit comments