22from fastapi .responses import RedirectResponse
33from typing import Optional
44from ..auth .auth_client import AuthClient
5- from ..config import Auth0Config
6- from ..util import to_safe_redirect , create_route_url , merge_set_cookie_headers
5+ from ..config import Auth0Config
6+ from ..util import to_safe_redirect , create_route_url
77
88router = APIRouter ()
99
10+
1011def get_auth_client (request : Request ) -> AuthClient :
1112 """
1213 Dependency function to retrieve the AuthClient instance.
1314 Assumes the client is set on the FastAPI application state.
1415 """
1516 auth_client = request .app .state .auth_client
1617 if not auth_client :
17- raise HTTPException (status_code = 500 , detail = "Authentication client not configured." )
18+ raise HTTPException (
19+ status_code = 500 , detail = "Authentication client not configured." )
1820 return auth_client
1921
22+
2023def register_auth_routes (router : APIRouter , config : Auth0Config ):
2124 """
2225 Conditionally register auth routes based on config.mount_routes and config.mount_connect_routes.
@@ -29,16 +32,14 @@ async def login(request: Request, response: Response, auth_client: AuthClient =
2932 Optionally accepts a 'return_to' query parameter and passes it as part of the app state.
3033 Redirects the user to the Auth0 authorization URL.
3134 """
32-
35+
3336 return_to : Optional [str ] = request .query_params .get ("returnTo" )
34- auth_url = await auth_client .start_login (
37+ auth_url = await auth_client .start_login (
3538 app_state = {"returnTo" : return_to } if return_to else None ,
3639 store_options = {"response" : response }
3740 )
3841
39- redirect_response = RedirectResponse (url = auth_url )
40-
41- return merge_set_cookie_headers (response , redirect_response )
42+ return RedirectResponse (url = auth_url , headers = response .headers )
4243
4344 @router .get ("/auth/callback" )
4445 async def callback (request : Request , response : Response , auth_client : AuthClient = Depends (get_auth_client )):
@@ -52,15 +53,14 @@ async def callback(request: Request, response: Response, auth_client: AuthClient
5253 session_data = await auth_client .complete_login (full_callback_url , store_options = {"request" : request , "response" : response })
5354 except Exception as e :
5455 raise HTTPException (status_code = 400 , detail = str (e ))
55-
56+
5657 # Extract the returnTo URL from the appState if available.
5758 return_to = session_data .get ("app_state" , {}).get ("returnTo" )
58-
59- default_redirect = auth_client .config .app_base_url # Assuming config is stored on app.state
60-
61- redirect_response = RedirectResponse (url = return_to or default_redirect )
6259
63- return merge_set_cookie_headers (response , redirect_response )
60+ # Assuming config is stored on app.state
61+ default_redirect = auth_client .config .app_base_url
62+
63+ return RedirectResponse (url = return_to or default_redirect , headers = response .headers )
6464
6565 @router .get ("/auth/logout" )
6666 async def logout (request : Request , response : Response , auth_client : AuthClient = Depends (get_auth_client )):
@@ -75,10 +75,8 @@ async def logout(request: Request, response: Response, auth_client: AuthClient =
7575 logout_url = await auth_client .logout (return_to = return_to or default_redirect , store_options = {"response" : response })
7676 except Exception as e :
7777 raise HTTPException (status_code = 500 , detail = str (e ))
78-
79- redirect_response = RedirectResponse (url = logout_url )
80-
81- return merge_set_cookie_headers (response , redirect_response )
78+
79+ return RedirectResponse (url = logout_url , headers = response .headers )
8280
8381 @router .post ("/auth/backchannel-logout" )
8482 async def backchannel_logout (request : Request , auth_client : AuthClient = Depends (get_auth_client )):
@@ -90,17 +88,18 @@ async def backchannel_logout(request: Request, auth_client: AuthClient = Depends
9088 body = await request .json ()
9189 logout_token = body .get ("logout_token" )
9290 if not logout_token :
93- raise HTTPException (status_code = 400 , detail = "Missing 'logout_token' in request body." )
94-
91+ raise HTTPException (
92+ status_code = 400 , detail = "Missing 'logout_token' in request body." )
93+
9594 try :
9695 await auth_client .handle_backchannel_logout (logout_token )
9796 except Exception as e :
9897 raise HTTPException (status_code = 400 , detail = str (e ))
9998 return Response (status_code = 204 )
100-
99+
101100 #################### Testing Route (Won't be there in the Fastify SDKs) ###################################
102101 @router .get ("/auth/profile" )
103- async def profile (request : Request , response :Response , auth_client : AuthClient = Depends (get_auth_client )):
102+ async def profile (request : Request , response : Response , auth_client : AuthClient = Depends (get_auth_client )):
104103 # Prepare store_options with the Request object (used by the state store to read cookies)
105104 store_options = {"request" : request , "response" : response }
106105 try :
@@ -109,21 +108,20 @@ async def profile(request: Request, response:Response, auth_client: AuthClient =
109108 session = await auth_client .client .get_session (store_options = store_options )
110109 except Exception as e :
111110 raise HTTPException (status_code = 400 , detail = str (e ))
112-
111+
113112 return {
114113 "user" : user ,
115114 "session" : session
116115 }
117116
118-
119117 @router .get ("/auth/token" )
120118 async def get_token (request : Request , response : Response , auth_client : AuthClient = Depends (get_auth_client )):
121119 # Prepare store_options with the Request object (used by the state store to read cookies)
122120 store_options = {"request" : request , "response" : response }
123121 try :
124122 # Retrieve access token from the client
125123 access_token = await auth_client .client .get_access_token (store_options = store_options )
126-
124+
127125 return {
128126 "access_token_available" : bool (access_token ),
129127 "access_token_preview" : access_token [:10 ] + "..." if access_token else None ,
@@ -132,33 +130,32 @@ async def get_token(request: Request, response: Response, auth_client: AuthClien
132130 except Exception as e :
133131 raise HTTPException (status_code = 400 , detail = str (e ))
134132
135-
136133 @router .get ("/auth/connection/{connection_name}" )
137134 async def get_connection_token (
138135 connection_name : str ,
139- request : Request ,
140- response : Response ,
136+ request : Request ,
137+ response : Response ,
141138 auth_client : AuthClient = Depends (get_auth_client ),
142139 login_hint : Optional [str ] = None
143140 ):
144141 store_options = {"request" : request , "response" : response }
145-
142+
146143 try :
147144 # Create connection options as a dictionary
148145 connection_options = {
149146 "connection" : connection_name
150147 }
151-
148+
152149 # Add login_hint if provided
153150 if login_hint :
154151 connection_options ["login_hint" ] = login_hint
155-
152+
156153 # Retrieve connection-specific access token
157154 access_token = await auth_client .client .get_access_token_for_connection (
158- connection_options ,
155+ connection_options ,
159156 store_options = store_options
160157 )
161-
158+
162159 # Return a response with token information
163160 return {
164161 "connection" : connection_name ,
@@ -170,34 +167,37 @@ async def get_connection_token(
170167 # Handle all errors with a single exception handler
171168 raise HTTPException (status_code = 400 , detail = str (e ))
172169 #################### ********Testing Routes End ****** ###################################
173-
170+
174171 if config .mount_connect_routes :
175172
176173 @router .get ("/auth/connect" )
177- async def connect (request : Request , response : Response ,
178- connection : Optional [str ] = Query (None ),
179- connectionScope : Optional [str ] = Query (None ),
180- returnTo : Optional [str ] = Query (None ),
181- auth_client : AuthClient = Depends (get_auth_client )):
174+ async def connect (request : Request , response : Response ,
175+ connection : Optional [str ] = Query (None ),
176+ connectionScope : Optional [str ] = Query (None ),
177+ returnTo : Optional [str ] = Query (None ),
178+ auth_client : AuthClient = Depends (get_auth_client )):
182179
183180 # Extract query parameters (connection, connectionScope, returnTo)
184181 connection = connection or request .query_params .get ("connection" )
185- connection_scope = connectionScope or request .query_params .get ("connectionScope" )
186- dangerous_return_to = returnTo or request .query_params .get ("returnTo" )
187-
182+ connection_scope = connectionScope or request .query_params .get (
183+ "connectionScope" )
184+ dangerous_return_to = returnTo or request .query_params .get (
185+ "returnTo" )
188186
189187 if not connection :
190188 raise HTTPException (
191189 status_code = 400 ,
192190 detail = "connection is not set"
193191 )
194-
195- sanitized_return_to = to_safe_redirect (dangerous_return_to or "/" , auth_client .config .app_base_url )
196-
192+
193+ sanitized_return_to = to_safe_redirect (
194+ dangerous_return_to or "/" , auth_client .config .app_base_url )
195+
197196 # Create the callback URL for linking
198197 callback_path = "/auth/connect/callback"
199- redirect_uri = create_route_url (callback_path , auth_client .config .app_base_url )
200-
198+ redirect_uri = create_route_url (
199+ callback_path , auth_client .config .app_base_url )
200+
201201 # Call the startLinkUser method on our AuthClient. This method should accept parameters similar to:
202202 # connection, connectionScope, authorizationParams (with redirect_uri), and app_state.
203203 link_user_url = await auth_client .start_link_user ({
@@ -211,9 +211,7 @@ async def connect(request: Request, response: Response,
211211 }
212212 }, store_options = {"request" : request , "response" : response })
213213
214- redirect_response = RedirectResponse (url = link_user_url )
215-
216- return merge_set_cookie_headers (response , redirect_response )
214+ return RedirectResponse (url = link_user_url , headers = response .headers )
217215
218216 @router .get ("/auth/connect/callback" )
219217 async def connect_callback (request : Request , response : Response , auth_client : AuthClient = Depends (get_auth_client )):
@@ -223,43 +221,40 @@ async def connect_callback(request: Request, response: Response, auth_client: Au
223221 result = await auth_client .complete_link_user (callback_url , store_options = {"request" : request , "response" : response })
224222 except Exception as e :
225223 raise HTTPException (status_code = 400 , detail = str (e ))
226-
224+
227225 # Retrieve the returnTo parameter from app_state if available
228226 return_to = result .get ("app_state" , {}).get ("returnTo" )
229227
230228 app_base_url = auth_client .config .app_base_url
231229
232- redirect_response = RedirectResponse (url = return_to or app_base_url )
233- if "set-cookie" in response .headers :
234- cookies = response .headers .getlist ("set-cookie" ) if hasattr (response .headers , "getlist" ) else [response .headers ["set-cookie" ]]
235- for cookie in cookies :
236- redirect_response .headers .append ("set-cookie" , cookie )
237- return redirect_response
238-
230+ return RedirectResponse (url = return_to or app_base_url , headers = response .headers )
231+
239232 @router .get ("/auth/unconnect" )
240- async def connect (request : Request , response : Response ,
241- connection : Optional [str ] = Query (None ),
242- connectionScope : Optional [str ] = Query (None ),
243- returnTo : Optional [str ] = Query (None ),
244- auth_client : AuthClient = Depends (get_auth_client )):
233+ async def connect (request : Request , response : Response ,
234+ connection : Optional [str ] = Query (None ),
235+ connectionScope : Optional [str ] = Query (None ),
236+ returnTo : Optional [str ] = Query (None ),
237+ auth_client : AuthClient = Depends (get_auth_client )):
245238
246239 # Extract query parameters (connection, connectionScope, returnTo)
247240 connection = connection or request .query_params .get ("connection" )
248- dangerous_return_to = returnTo or request .query_params .get ("returnTo" )
249-
241+ dangerous_return_to = returnTo or request .query_params .get (
242+ "returnTo" )
250243
251244 if not connection :
252245 raise HTTPException (
253246 status_code = 400 ,
254247 detail = "connection is not set"
255248 )
256-
257- sanitized_return_to = to_safe_redirect (dangerous_return_to or "/" , auth_client .config .app_base_url )
258-
249+
250+ sanitized_return_to = to_safe_redirect (
251+ dangerous_return_to or "/" , auth_client .config .app_base_url )
252+
259253 # Create the callback URL for linking
260254 callback_path = "/auth/unconnect/callback"
261- redirect_uri = create_route_url (callback_path , auth_client .config .app_base_url )
262-
255+ redirect_uri = create_route_url (
256+ callback_path , auth_client .config .app_base_url )
257+
263258 # Call the startLinkUser method on our AuthClient. This method should accept parameters similar to:
264259 # connection, connectionScope, authorizationParams (with redirect_uri), and app_state.
265260 link_user_url = await auth_client .start_unlink_user ({
@@ -272,9 +267,7 @@ async def connect(request: Request, response: Response,
272267 }
273268 }, store_options = {"request" : request , "response" : response })
274269
275- redirect_response = RedirectResponse (url = link_user_url )
276-
277- return merge_set_cookie_headers (response , redirect_response )
270+ return RedirectResponse (url = link_user_url , headers = response .headers )
278271
279272 @router .get ("/auth/unconnect/callback" )
280273 async def unconnect_callback (request : Request , response : Response , auth_client : AuthClient = Depends (get_auth_client )):
@@ -284,13 +277,10 @@ async def unconnect_callback(request: Request, response: Response, auth_client:
284277 result = await auth_client .complete_unlink_user (callback_url , store_options = {"request" : request , "response" : response })
285278 except Exception as e :
286279 raise HTTPException (status_code = 400 , detail = str (e ))
287-
280+
288281 # Retrieve the returnTo parameter from appState if available
289282 return_to = result .get ("app_state" , {}).get ("returnTo" )
290283
291284 app_base_url = auth_client .config .app_base_url
292285
293- redirect_response = RedirectResponse (url = return_to or app_base_url )
294-
295- return merge_set_cookie_headers (response , redirect_response )
296-
286+ return RedirectResponse (url = return_to or app_base_url , headers = response .headers )
0 commit comments