Skip to content

Commit 7b86e07

Browse files
kraftpqianl15
andauthored
Database Wizard (#176)
This PR adds a "wizard" that guides new users through connecting to Postgres. If a Postgres connection with default parameters can't be established, it first tries to launch Postgres using Docker then attempts to connect to a Postgres database on DBOS Cloud. --------- Co-authored-by: Qian Li <[email protected]>
1 parent d04a9ba commit 7b86e07

File tree

10 files changed

+1053
-61
lines changed

10 files changed

+1053
-61
lines changed

dbos/_cloudutils/authentication.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import os
2+
import time
3+
from dataclasses import dataclass
4+
from typing import Any, Dict, Optional
5+
6+
import jwt
7+
import requests
8+
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
9+
from cryptography.x509 import load_pem_x509_certificate
10+
from rich import print
11+
12+
from .._logger import dbos_logger
13+
14+
# Constants
15+
DBOS_CLOUD_HOST = os.getenv("DBOS_DOMAIN", "cloud.dbos.dev")
16+
PRODUCTION_ENVIRONMENT = DBOS_CLOUD_HOST == "cloud.dbos.dev"
17+
AUTH0_DOMAIN = "login.dbos.dev" if PRODUCTION_ENVIRONMENT else "dbos-inc.us.auth0.com"
18+
DBOS_CLIENT_ID = (
19+
"6p7Sjxf13cyLMkdwn14MxlH7JdhILled"
20+
if PRODUCTION_ENVIRONMENT
21+
else "G38fLmVErczEo9ioCFjVIHea6yd0qMZu"
22+
)
23+
DBOS_CLOUD_IDENTIFIER = "dbos-cloud-api"
24+
25+
26+
@dataclass
27+
class DeviceCodeResponse:
28+
device_code: str
29+
user_code: str
30+
verification_uri: str
31+
verification_uri_complete: str
32+
expires_in: int
33+
interval: int
34+
35+
@classmethod
36+
def from_dict(cls, data: Dict[str, Any]) -> "DeviceCodeResponse":
37+
return cls(
38+
device_code=data["device_code"],
39+
user_code=data["user_code"],
40+
verification_uri=data["verification_uri"],
41+
verification_uri_complete=data["verification_uri_complete"],
42+
expires_in=data["expires_in"],
43+
interval=data["interval"],
44+
)
45+
46+
47+
@dataclass
48+
class TokenResponse:
49+
access_token: str
50+
token_type: str
51+
expires_in: int
52+
refresh_token: Optional[str] = None
53+
54+
@classmethod
55+
def from_dict(cls, data: Dict[str, Any]) -> "TokenResponse":
56+
return cls(
57+
access_token=data["access_token"],
58+
token_type=data["token_type"],
59+
expires_in=data["expires_in"],
60+
refresh_token=data.get("refresh_token"),
61+
)
62+
63+
64+
@dataclass
65+
class AuthenticationResponse:
66+
token: str
67+
refresh_token: Optional[str] = None
68+
69+
70+
class JWKSClient:
71+
def __init__(self, jwks_uri: str):
72+
self.jwks_uri = jwks_uri
73+
74+
def get_signing_key(self, kid: str) -> RSAPublicKey:
75+
response = requests.get(self.jwks_uri)
76+
jwks = response.json()
77+
for key in jwks["keys"]:
78+
if key["kid"] == kid:
79+
cert_text = f"-----BEGIN CERTIFICATE-----\n{key['x5c'][0]}\n-----END CERTIFICATE-----"
80+
cert = load_pem_x509_certificate(cert_text.encode())
81+
return cert.public_key() # type: ignore
82+
raise Exception(f"Unable to find signing key with kid: {kid}")
83+
84+
85+
def verify_token(token: str) -> None:
86+
header = jwt.get_unverified_header(token)
87+
88+
if not header.get("kid"):
89+
raise ValueError("Invalid token: No 'kid' in header")
90+
91+
client = JWKSClient(f"https://{AUTH0_DOMAIN}/.well-known/jwks.json")
92+
signing_key = client.get_signing_key(header["kid"])
93+
jwt.decode(
94+
token,
95+
signing_key,
96+
algorithms=["RS256"],
97+
audience=DBOS_CLOUD_IDENTIFIER,
98+
options={
99+
"verify_iat": False,
100+
"clock_tolerance": 60,
101+
},
102+
)
103+
104+
105+
def authenticate(get_refresh_token: bool = False) -> Optional[AuthenticationResponse]:
106+
print(
107+
"[bold blue]Please authenticate with DBOS Cloud to access a Postgres database[/bold blue]"
108+
)
109+
110+
# Get device code
111+
device_code_data = {
112+
"client_id": DBOS_CLIENT_ID,
113+
"scope": "offline_access" if get_refresh_token else "sub",
114+
"audience": DBOS_CLOUD_IDENTIFIER,
115+
}
116+
117+
try:
118+
response = requests.post(
119+
f"https://{AUTH0_DOMAIN}/oauth/device/code",
120+
data=device_code_data,
121+
headers={"content-type": "application/x-www-form-urlencoded"},
122+
)
123+
device_code_response = DeviceCodeResponse.from_dict(response.json())
124+
except Exception as e:
125+
dbos_logger.error(f"Failed to log in: {str(e)}")
126+
return None
127+
128+
login_url = device_code_response.verification_uri_complete
129+
print(f"[bold blue]Login URL:[/bold blue] {login_url}")
130+
131+
# Poll for token
132+
token_data = {
133+
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
134+
"device_code": device_code_response.device_code,
135+
"client_id": DBOS_CLIENT_ID,
136+
}
137+
138+
elapsed_time_sec = 0
139+
token_response = None
140+
141+
while elapsed_time_sec < device_code_response.expires_in:
142+
try:
143+
time.sleep(device_code_response.interval)
144+
elapsed_time_sec += device_code_response.interval
145+
146+
response = requests.post(
147+
f"https://{AUTH0_DOMAIN}/oauth/token",
148+
data=token_data,
149+
headers={"content-type": "application/x-www-form-urlencoded"},
150+
)
151+
if response.status_code == 200:
152+
token_response = TokenResponse.from_dict(response.json())
153+
break
154+
except Exception:
155+
dbos_logger.info("Waiting for login...")
156+
157+
if not token_response:
158+
return None
159+
160+
verify_token(token_response.access_token)
161+
return AuthenticationResponse(
162+
token=token_response.access_token, refresh_token=token_response.refresh_token
163+
)

0 commit comments

Comments
 (0)