|
| 1 | +import os |
| 2 | +import time |
| 3 | +from dataclasses import dataclass |
| 4 | +from typing import Any, Dict, Optional |
| 5 | + |
| 6 | +import jwt |
| 7 | +import requests |
| 8 | +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey |
| 9 | +from cryptography.x509 import load_pem_x509_certificate |
| 10 | +from rich import print |
| 11 | + |
| 12 | +from .._logger import dbos_logger |
| 13 | + |
| 14 | +# Constants |
| 15 | +DBOS_CLOUD_HOST = os.getenv("DBOS_DOMAIN", "cloud.dbos.dev") |
| 16 | +PRODUCTION_ENVIRONMENT = DBOS_CLOUD_HOST == "cloud.dbos.dev" |
| 17 | +AUTH0_DOMAIN = "login.dbos.dev" if PRODUCTION_ENVIRONMENT else "dbos-inc.us.auth0.com" |
| 18 | +DBOS_CLIENT_ID = ( |
| 19 | + "6p7Sjxf13cyLMkdwn14MxlH7JdhILled" |
| 20 | + if PRODUCTION_ENVIRONMENT |
| 21 | + else "G38fLmVErczEo9ioCFjVIHea6yd0qMZu" |
| 22 | +) |
| 23 | +DBOS_CLOUD_IDENTIFIER = "dbos-cloud-api" |
| 24 | + |
| 25 | + |
| 26 | +@dataclass |
| 27 | +class DeviceCodeResponse: |
| 28 | + device_code: str |
| 29 | + user_code: str |
| 30 | + verification_uri: str |
| 31 | + verification_uri_complete: str |
| 32 | + expires_in: int |
| 33 | + interval: int |
| 34 | + |
| 35 | + @classmethod |
| 36 | + def from_dict(cls, data: Dict[str, Any]) -> "DeviceCodeResponse": |
| 37 | + return cls( |
| 38 | + device_code=data["device_code"], |
| 39 | + user_code=data["user_code"], |
| 40 | + verification_uri=data["verification_uri"], |
| 41 | + verification_uri_complete=data["verification_uri_complete"], |
| 42 | + expires_in=data["expires_in"], |
| 43 | + interval=data["interval"], |
| 44 | + ) |
| 45 | + |
| 46 | + |
| 47 | +@dataclass |
| 48 | +class TokenResponse: |
| 49 | + access_token: str |
| 50 | + token_type: str |
| 51 | + expires_in: int |
| 52 | + refresh_token: Optional[str] = None |
| 53 | + |
| 54 | + @classmethod |
| 55 | + def from_dict(cls, data: Dict[str, Any]) -> "TokenResponse": |
| 56 | + return cls( |
| 57 | + access_token=data["access_token"], |
| 58 | + token_type=data["token_type"], |
| 59 | + expires_in=data["expires_in"], |
| 60 | + refresh_token=data.get("refresh_token"), |
| 61 | + ) |
| 62 | + |
| 63 | + |
| 64 | +@dataclass |
| 65 | +class AuthenticationResponse: |
| 66 | + token: str |
| 67 | + refresh_token: Optional[str] = None |
| 68 | + |
| 69 | + |
| 70 | +class JWKSClient: |
| 71 | + def __init__(self, jwks_uri: str): |
| 72 | + self.jwks_uri = jwks_uri |
| 73 | + |
| 74 | + def get_signing_key(self, kid: str) -> RSAPublicKey: |
| 75 | + response = requests.get(self.jwks_uri) |
| 76 | + jwks = response.json() |
| 77 | + for key in jwks["keys"]: |
| 78 | + if key["kid"] == kid: |
| 79 | + cert_text = f"-----BEGIN CERTIFICATE-----\n{key['x5c'][0]}\n-----END CERTIFICATE-----" |
| 80 | + cert = load_pem_x509_certificate(cert_text.encode()) |
| 81 | + return cert.public_key() # type: ignore |
| 82 | + raise Exception(f"Unable to find signing key with kid: {kid}") |
| 83 | + |
| 84 | + |
| 85 | +def verify_token(token: str) -> None: |
| 86 | + header = jwt.get_unverified_header(token) |
| 87 | + |
| 88 | + if not header.get("kid"): |
| 89 | + raise ValueError("Invalid token: No 'kid' in header") |
| 90 | + |
| 91 | + client = JWKSClient(f"https://{AUTH0_DOMAIN}/.well-known/jwks.json") |
| 92 | + signing_key = client.get_signing_key(header["kid"]) |
| 93 | + jwt.decode( |
| 94 | + token, |
| 95 | + signing_key, |
| 96 | + algorithms=["RS256"], |
| 97 | + audience=DBOS_CLOUD_IDENTIFIER, |
| 98 | + options={ |
| 99 | + "verify_iat": False, |
| 100 | + "clock_tolerance": 60, |
| 101 | + }, |
| 102 | + ) |
| 103 | + |
| 104 | + |
| 105 | +def authenticate(get_refresh_token: bool = False) -> Optional[AuthenticationResponse]: |
| 106 | + print( |
| 107 | + "[bold blue]Please authenticate with DBOS Cloud to access a Postgres database[/bold blue]" |
| 108 | + ) |
| 109 | + |
| 110 | + # Get device code |
| 111 | + device_code_data = { |
| 112 | + "client_id": DBOS_CLIENT_ID, |
| 113 | + "scope": "offline_access" if get_refresh_token else "sub", |
| 114 | + "audience": DBOS_CLOUD_IDENTIFIER, |
| 115 | + } |
| 116 | + |
| 117 | + try: |
| 118 | + response = requests.post( |
| 119 | + f"https://{AUTH0_DOMAIN}/oauth/device/code", |
| 120 | + data=device_code_data, |
| 121 | + headers={"content-type": "application/x-www-form-urlencoded"}, |
| 122 | + ) |
| 123 | + device_code_response = DeviceCodeResponse.from_dict(response.json()) |
| 124 | + except Exception as e: |
| 125 | + dbos_logger.error(f"Failed to log in: {str(e)}") |
| 126 | + return None |
| 127 | + |
| 128 | + login_url = device_code_response.verification_uri_complete |
| 129 | + print(f"[bold blue]Login URL:[/bold blue] {login_url}") |
| 130 | + |
| 131 | + # Poll for token |
| 132 | + token_data = { |
| 133 | + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", |
| 134 | + "device_code": device_code_response.device_code, |
| 135 | + "client_id": DBOS_CLIENT_ID, |
| 136 | + } |
| 137 | + |
| 138 | + elapsed_time_sec = 0 |
| 139 | + token_response = None |
| 140 | + |
| 141 | + while elapsed_time_sec < device_code_response.expires_in: |
| 142 | + try: |
| 143 | + time.sleep(device_code_response.interval) |
| 144 | + elapsed_time_sec += device_code_response.interval |
| 145 | + |
| 146 | + response = requests.post( |
| 147 | + f"https://{AUTH0_DOMAIN}/oauth/token", |
| 148 | + data=token_data, |
| 149 | + headers={"content-type": "application/x-www-form-urlencoded"}, |
| 150 | + ) |
| 151 | + if response.status_code == 200: |
| 152 | + token_response = TokenResponse.from_dict(response.json()) |
| 153 | + break |
| 154 | + except Exception: |
| 155 | + dbos_logger.info("Waiting for login...") |
| 156 | + |
| 157 | + if not token_response: |
| 158 | + return None |
| 159 | + |
| 160 | + verify_token(token_response.access_token) |
| 161 | + return AuthenticationResponse( |
| 162 | + token=token_response.access_token, refresh_token=token_response.refresh_token |
| 163 | + ) |
0 commit comments