88import asyncio
99from typing import Any , Dict , List , Optional , Tuple
1010from urllib .parse import urlencode
11+ from carbonserver .config import settings
1112
1213import httpx
13- from fastapi_oidc import discovery
14+ from fastapi_oidc import discovery , get_auth
1415from jose import jwt
1516
1617DEFAULT_SIGNATURE_CACHE_TTL = 3600 # seconds
17-
18+ OAUTH_SCOPES = ["openid" , "email" , "profile" ]
19+
20+ from authlib .integrations .starlette_client import OAuth
21+ oauth = OAuth ()
22+ oauth .register (
23+ "client" ,
24+ client_id = settings .oidc_client_id ,
25+ client_secret = settings .oidc_client_secret ,
26+ server_metadata_url = settings .oidc_well_known_url ,
27+ client_kwargs = {"scope" : "openid profile email" },
28+ )
1829
1930class OIDCAuthProvider :
20- """
21- Generic OIDC authentication provider implementation.
22-
23- This class uses OIDC discovery and validation (via fastapi-oidc) to interact with
24- any OIDC-compliant authentication server (such as Fief, Keycloak, Auth0, etc.).
25- """
26-
2731 def __init__ (
2832 self ,
2933 base_url : str ,
@@ -33,191 +37,10 @@ def __init__(
3337 signature_cache_ttl : int = DEFAULT_SIGNATURE_CACHE_TTL ,
3438 openid_configuration : Optional [Dict [str , Any ]] = None ,
3539 ):
36- """
37- Initialize the OIDC authentication provider.
38-
39- Args:
40- base_url: The OIDC issuer URL (base URL of the authentication server)
41- client_id: The OAuth2 client ID
42- client_secret: The OAuth2 client secret
43- signature_cache_ttl: Seconds to cache the OIDC discovery/JWKS responses
44- openid_configuration: Optional pre-loaded OIDC configuration (used mainly for testing)
45- """
46- self .base_url = base_url .rstrip ("/" )
47- self .client_id = client_id
48- self .client_secret = client_secret
49- self ._discovery = discovery .configure (cache_ttl = signature_cache_ttl )
50- self ._openid_configuration = openid_configuration
51-
52- async def _get_openid_configuration (self ) -> Dict [str , Any ]:
53- if self ._openid_configuration is None :
54- self ._openid_configuration = await asyncio .to_thread (
55- self ._discovery .auth_server , base_url = self .base_url
56- )
57- return self ._openid_configuration
58-
59- async def _get_jwks (self ) -> Dict [str , Any ]:
60- oidc_config = await self ._get_openid_configuration ()
61- return await asyncio .to_thread (self ._discovery .public_keys , oidc_config )
62-
63- async def _get_algorithms (self ) -> List [str ]:
64- oidc_config = await self ._get_openid_configuration ()
65- return await asyncio .to_thread (self ._discovery .signing_algos , oidc_config )
66-
67- async def _decode_token (self , token : str ) -> Dict [str , Any ]:
68- oidc_config = await self ._get_openid_configuration ()
69- jwks = await self ._get_jwks ()
70- algorithms = await self ._get_algorithms ()
71- return jwt .decode (
72- token ,
73- jwks ,
74- algorithms = algorithms ,
75- issuer = oidc_config .get ("issuer" , self .base_url ),
76- options = {"verify_aud" : False , "verify_at_hash" : False },
77- )
78-
79- async def get_auth_url (
80- self , redirect_uri : str , scope : List [str ], state : Optional [str ] = None
81- ) -> str :
82- """
83- Generate the authorization URL for the OAuth2 flow.
84-
85- Args:
86- redirect_uri: The URI to redirect to after authentication
87- scope: List of OAuth2 scopes to request
88- state: Optional state parameter for CSRF protection
89-
90- Returns:
91- The authorization URL to redirect the user to
92- """
93- oidc_config = await self ._get_openid_configuration ()
94- authorize_endpoint = oidc_config .get (
95- "authorization_endpoint" , f"{ self .base_url } /authorize"
96- )
97- params = {
98- "response_type" : "code" ,
99- "client_id" : self .client_id ,
100- "redirect_uri" : redirect_uri ,
101- "scope" : " " .join (scope ),
102- }
103- if state is not None :
104- params ["state" ] = state
105-
106- return f"{ authorize_endpoint } ?{ urlencode (params )} "
107-
108- async def handle_auth_callback (
109- self , code : str , redirect_uri : str
110- ) -> Tuple [Dict [str , Any ], Optional [Dict [str , Any ]]]:
111- """
112- Handle the OAuth2 callback and exchange the code for tokens.
113-
114- Args:
115- code: The authorization code from the OAuth2 provider
116- redirect_uri: The redirect URI used in the initial auth request
40+ self .client = oauth ._clients ["client" ]
11741
118- Returns:
119- A tuple of (tokens, user_info) where:
120- - tokens: Dict containing access_token, refresh_token, expires_in, etc.
121- - user_info: Optional dict containing user information
122- """
123- oidc_config = await self ._get_openid_configuration ()
124- token_endpoint = oidc_config .get ("token_endpoint" , f"{ self .base_url } /api/token" )
125- async with httpx .AsyncClient () as client :
126- response = await client .post (
127- token_endpoint ,
128- data = {
129- "grant_type" : "authorization_code" ,
130- "code" : code ,
131- "redirect_uri" : redirect_uri ,
132- "client_id" : self .client_id ,
133- "client_secret" : self .client_secret ,
134- },
135- headers = {"accept" : "application/json" },
136- )
137- response .raise_for_status ()
138- tokens : Dict [str , Any ] = response .json ()
139-
140- user_info : Optional [Dict [str , Any ]] = None
141- if "id_token" in tokens :
142- user_info = await self ._decode_token (tokens ["id_token" ])
143- elif "access_token" in tokens :
144- try :
145- user_info = await self .get_user_info (tokens ["access_token" ])
146- except Exception :
147- # If userinfo fails we still return tokens
148- user_info = None
149-
150- return (tokens , user_info )
151-
152- async def validate_access_token (self , token : str ) -> bool :
153- """
154- Validate an access token.
155-
156- Args:
157- token: The access token to validate
158-
159- Returns:
160- True if the token is valid
161-
162- Raises:
163- Exception if validation fails
164- """
165- await self ._decode_token (token )
166- return True
167-
168- async def get_user_info (self , access_token : str ) -> Dict [str , Any ]:
169- """
170- Get user information from the OIDC provider.
171-
172- Args:
173- access_token: The access token for the user
174-
175- Returns:
176- Dict containing user information (sub, email, name, etc.)
177- """
178- oidc_config = await self ._get_openid_configuration ()
179- userinfo_endpoint = oidc_config .get (
180- "userinfo_endpoint" , f"{ self .base_url } /api/userinfo"
181- )
182- headers = {"Authorization" : f"Bearer { access_token } " }
183- async with httpx .AsyncClient () as client :
184- response = await client .get (userinfo_endpoint , headers = headers )
185- response .raise_for_status ()
186- return response .json ()
187-
188- def get_token_endpoint (self ) -> str :
189- """
190- Get the token endpoint URL.
191-
192- Returns:
193- The token endpoint URL
194- """
195- if (
196- self ._openid_configuration
197- and "token_endpoint" in self ._openid_configuration
198- ):
199- return self ._openid_configuration ["token_endpoint" ]
200- return f"{ self .base_url } /api/token"
201-
202- def get_authorize_endpoint (self ) -> str :
203- """
204- Get the authorization endpoint URL.
205-
206- Returns:
207- The authorization endpoint URL
208- """
209- if (
210- self ._openid_configuration
211- and "authorization_endpoint" in self ._openid_configuration
212- ):
213- return self ._openid_configuration ["authorization_endpoint" ]
214- return f"{ self .base_url } /authorize"
42+ async def get_authorize_url (self , request , login_url ):
43+ return await self .client .authorize_redirect (request , str (login_url ), scope = ' ' .join (OAUTH_SCOPES ))
21544
21645 def get_client_credentials (self ) -> Tuple [str , str ]:
217- """
218- Get the client ID and client secret.
219-
220- Returns:
221- A tuple of (client_id, client_secret)
222- """
223- return (self .client_id , self .client_secret )
46+ return (self .client .client_id , self .client .client_secret )
0 commit comments