Skip to content

Commit 8dd2d8c

Browse files
fix(util): remove unnecessary merge header cookis util (#13)
1 parent 9c1ef48 commit 8dd2d8c

File tree

2 files changed

+83
-117
lines changed

2 files changed

+83
-117
lines changed

packages/auth0_fastapi/src/auth0_fastapi/server/routes.py

Lines changed: 68 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,24 @@
22
from fastapi.responses import RedirectResponse
33
from typing import Optional
44
from ..auth.auth_client import AuthClient
5-
from ..config import Auth0Config
6-
from ..util import to_safe_redirect, create_route_url, merge_set_cookie_headers
5+
from ..config import Auth0Config
6+
from ..util import to_safe_redirect, create_route_url
77

88
router = APIRouter()
99

10+
1011
def get_auth_client(request: Request) -> AuthClient:
1112
"""
1213
Dependency function to retrieve the AuthClient instance.
1314
Assumes the client is set on the FastAPI application state.
1415
"""
1516
auth_client = request.app.state.auth_client
1617
if not auth_client:
17-
raise HTTPException(status_code=500, detail="Authentication client not configured.")
18+
raise HTTPException(
19+
status_code=500, detail="Authentication client not configured.")
1820
return auth_client
1921

22+
2023
def register_auth_routes(router: APIRouter, config: Auth0Config):
2124
"""
2225
Conditionally register auth routes based on config.mount_routes and config.mount_connect_routes.
@@ -29,16 +32,14 @@ async def login(request: Request, response: Response, auth_client: AuthClient =
2932
Optionally accepts a 'return_to' query parameter and passes it as part of the app state.
3033
Redirects the user to the Auth0 authorization URL.
3134
"""
32-
35+
3336
return_to: Optional[str] = request.query_params.get("returnTo")
34-
auth_url = await auth_client.start_login(
37+
auth_url = await auth_client.start_login(
3538
app_state={"returnTo": return_to} if return_to else None,
3639
store_options={"response": response}
3740
)
3841

39-
redirect_response = RedirectResponse(url=auth_url)
40-
41-
return merge_set_cookie_headers(response, redirect_response)
42+
return RedirectResponse(url=auth_url, headers=response.headers)
4243

4344
@router.get("/auth/callback")
4445
async def callback(request: Request, response: Response, auth_client: AuthClient = Depends(get_auth_client)):
@@ -52,15 +53,14 @@ async def callback(request: Request, response: Response, auth_client: AuthClient
5253
session_data = await auth_client.complete_login(full_callback_url, store_options={"request": request, "response": response})
5354
except Exception as e:
5455
raise HTTPException(status_code=400, detail=str(e))
55-
56+
5657
# Extract the returnTo URL from the appState if available.
5758
return_to = session_data.get("app_state", {}).get("returnTo")
58-
59-
default_redirect = auth_client.config.app_base_url # Assuming config is stored on app.state
60-
61-
redirect_response = RedirectResponse(url=return_to or default_redirect)
6259

63-
return merge_set_cookie_headers(response, redirect_response)
60+
# Assuming config is stored on app.state
61+
default_redirect = auth_client.config.app_base_url
62+
63+
return RedirectResponse(url=return_to or default_redirect, headers=response.headers)
6464

6565
@router.get("/auth/logout")
6666
async def logout(request: Request, response: Response, auth_client: AuthClient = Depends(get_auth_client)):
@@ -75,10 +75,8 @@ async def logout(request: Request, response: Response, auth_client: AuthClient =
7575
logout_url = await auth_client.logout(return_to=return_to or default_redirect, store_options={"response": response})
7676
except Exception as e:
7777
raise HTTPException(status_code=500, detail=str(e))
78-
79-
redirect_response = RedirectResponse(url=logout_url)
80-
81-
return merge_set_cookie_headers(response, redirect_response)
78+
79+
return RedirectResponse(url=logout_url, headers=response.headers)
8280

8381
@router.post("/auth/backchannel-logout")
8482
async def backchannel_logout(request: Request, auth_client: AuthClient = Depends(get_auth_client)):
@@ -90,17 +88,18 @@ async def backchannel_logout(request: Request, auth_client: AuthClient = Depends
9088
body = await request.json()
9189
logout_token = body.get("logout_token")
9290
if not logout_token:
93-
raise HTTPException(status_code=400, detail="Missing 'logout_token' in request body.")
94-
91+
raise HTTPException(
92+
status_code=400, detail="Missing 'logout_token' in request body.")
93+
9594
try:
9695
await auth_client.handle_backchannel_logout(logout_token)
9796
except Exception as e:
9897
raise HTTPException(status_code=400, detail=str(e))
9998
return Response(status_code=204)
100-
99+
101100
#################### Testing Route (Won't be there in the Fastify SDKs) ###################################
102101
@router.get("/auth/profile")
103-
async def profile(request: Request, response:Response, auth_client: AuthClient = Depends(get_auth_client)):
102+
async def profile(request: Request, response: Response, auth_client: AuthClient = Depends(get_auth_client)):
104103
# Prepare store_options with the Request object (used by the state store to read cookies)
105104
store_options = {"request": request, "response": response}
106105
try:
@@ -109,21 +108,20 @@ async def profile(request: Request, response:Response, auth_client: AuthClient =
109108
session = await auth_client.client.get_session(store_options=store_options)
110109
except Exception as e:
111110
raise HTTPException(status_code=400, detail=str(e))
112-
111+
113112
return {
114113
"user": user,
115114
"session": session
116115
}
117116

118-
119117
@router.get("/auth/token")
120118
async def get_token(request: Request, response: Response, auth_client: AuthClient = Depends(get_auth_client)):
121119
# Prepare store_options with the Request object (used by the state store to read cookies)
122120
store_options = {"request": request, "response": response}
123121
try:
124122
# Retrieve access token from the client
125123
access_token = await auth_client.client.get_access_token(store_options=store_options)
126-
124+
127125
return {
128126
"access_token_available": bool(access_token),
129127
"access_token_preview": access_token[:10] + "..." if access_token else None,
@@ -132,33 +130,32 @@ async def get_token(request: Request, response: Response, auth_client: AuthClien
132130
except Exception as e:
133131
raise HTTPException(status_code=400, detail=str(e))
134132

135-
136133
@router.get("/auth/connection/{connection_name}")
137134
async def get_connection_token(
138135
connection_name: str,
139-
request: Request,
140-
response: Response,
136+
request: Request,
137+
response: Response,
141138
auth_client: AuthClient = Depends(get_auth_client),
142139
login_hint: Optional[str] = None
143140
):
144141
store_options = {"request": request, "response": response}
145-
142+
146143
try:
147144
# Create connection options as a dictionary
148145
connection_options = {
149146
"connection": connection_name
150147
}
151-
148+
152149
# Add login_hint if provided
153150
if login_hint:
154151
connection_options["login_hint"] = login_hint
155-
152+
156153
# Retrieve connection-specific access token
157154
access_token = await auth_client.client.get_access_token_for_connection(
158-
connection_options,
155+
connection_options,
159156
store_options=store_options
160157
)
161-
158+
162159
# Return a response with token information
163160
return {
164161
"connection": connection_name,
@@ -170,34 +167,37 @@ async def get_connection_token(
170167
# Handle all errors with a single exception handler
171168
raise HTTPException(status_code=400, detail=str(e))
172169
#################### ********Testing Routes End ****** ###################################
173-
170+
174171
if config.mount_connect_routes:
175172

176173
@router.get("/auth/connect")
177-
async def connect(request: Request, response: Response,
178-
connection: Optional[str] = Query(None),
179-
connectionScope: Optional[str] = Query(None),
180-
returnTo: Optional[str] = Query(None),
181-
auth_client: AuthClient = Depends(get_auth_client)):
174+
async def connect(request: Request, response: Response,
175+
connection: Optional[str] = Query(None),
176+
connectionScope: Optional[str] = Query(None),
177+
returnTo: Optional[str] = Query(None),
178+
auth_client: AuthClient = Depends(get_auth_client)):
182179

183180
# Extract query parameters (connection, connectionScope, returnTo)
184181
connection = connection or request.query_params.get("connection")
185-
connection_scope = connectionScope or request.query_params.get("connectionScope")
186-
dangerous_return_to = returnTo or request.query_params.get("returnTo")
187-
182+
connection_scope = connectionScope or request.query_params.get(
183+
"connectionScope")
184+
dangerous_return_to = returnTo or request.query_params.get(
185+
"returnTo")
188186

189187
if not connection:
190188
raise HTTPException(
191189
status_code=400,
192190
detail="connection is not set"
193191
)
194-
195-
sanitized_return_to = to_safe_redirect(dangerous_return_to or "/", auth_client.config.app_base_url)
196-
192+
193+
sanitized_return_to = to_safe_redirect(
194+
dangerous_return_to or "/", auth_client.config.app_base_url)
195+
197196
# Create the callback URL for linking
198197
callback_path = "/auth/connect/callback"
199-
redirect_uri = create_route_url(callback_path, auth_client.config.app_base_url)
200-
198+
redirect_uri = create_route_url(
199+
callback_path, auth_client.config.app_base_url)
200+
201201
# Call the startLinkUser method on our AuthClient. This method should accept parameters similar to:
202202
# connection, connectionScope, authorizationParams (with redirect_uri), and app_state.
203203
link_user_url = await auth_client.start_link_user({
@@ -211,9 +211,7 @@ async def connect(request: Request, response: Response,
211211
}
212212
}, store_options={"request": request, "response": response})
213213

214-
redirect_response = RedirectResponse(url=link_user_url)
215-
216-
return merge_set_cookie_headers(response, redirect_response)
214+
return RedirectResponse(url=link_user_url, headers=response.headers)
217215

218216
@router.get("/auth/connect/callback")
219217
async def connect_callback(request: Request, response: Response, auth_client: AuthClient = Depends(get_auth_client)):
@@ -223,43 +221,40 @@ async def connect_callback(request: Request, response: Response, auth_client: Au
223221
result = await auth_client.complete_link_user(callback_url, store_options={"request": request, "response": response})
224222
except Exception as e:
225223
raise HTTPException(status_code=400, detail=str(e))
226-
224+
227225
# Retrieve the returnTo parameter from app_state if available
228226
return_to = result.get("app_state", {}).get("returnTo")
229227

230228
app_base_url = auth_client.config.app_base_url
231229

232-
redirect_response = RedirectResponse(url=return_to or app_base_url)
233-
if "set-cookie" in response.headers:
234-
cookies = response.headers.getlist("set-cookie") if hasattr(response.headers, "getlist") else [response.headers["set-cookie"]]
235-
for cookie in cookies:
236-
redirect_response.headers.append("set-cookie", cookie)
237-
return redirect_response
238-
230+
return RedirectResponse(url=return_to or app_base_url, headers=response.headers)
231+
239232
@router.get("/auth/unconnect")
240-
async def connect(request: Request, response: Response,
241-
connection: Optional[str] = Query(None),
242-
connectionScope: Optional[str] = Query(None),
243-
returnTo: Optional[str] = Query(None),
244-
auth_client: AuthClient = Depends(get_auth_client)):
233+
async def connect(request: Request, response: Response,
234+
connection: Optional[str] = Query(None),
235+
connectionScope: Optional[str] = Query(None),
236+
returnTo: Optional[str] = Query(None),
237+
auth_client: AuthClient = Depends(get_auth_client)):
245238

246239
# Extract query parameters (connection, connectionScope, returnTo)
247240
connection = connection or request.query_params.get("connection")
248-
dangerous_return_to = returnTo or request.query_params.get("returnTo")
249-
241+
dangerous_return_to = returnTo or request.query_params.get(
242+
"returnTo")
250243

251244
if not connection:
252245
raise HTTPException(
253246
status_code=400,
254247
detail="connection is not set"
255248
)
256-
257-
sanitized_return_to = to_safe_redirect(dangerous_return_to or "/", auth_client.config.app_base_url)
258-
249+
250+
sanitized_return_to = to_safe_redirect(
251+
dangerous_return_to or "/", auth_client.config.app_base_url)
252+
259253
# Create the callback URL for linking
260254
callback_path = "/auth/unconnect/callback"
261-
redirect_uri = create_route_url(callback_path, auth_client.config.app_base_url)
262-
255+
redirect_uri = create_route_url(
256+
callback_path, auth_client.config.app_base_url)
257+
263258
# Call the startLinkUser method on our AuthClient. This method should accept parameters similar to:
264259
# connection, connectionScope, authorizationParams (with redirect_uri), and app_state.
265260
link_user_url = await auth_client.start_unlink_user({
@@ -272,9 +267,7 @@ async def connect(request: Request, response: Response,
272267
}
273268
}, store_options={"request": request, "response": response})
274269

275-
redirect_response = RedirectResponse(url=link_user_url)
276-
277-
return merge_set_cookie_headers(response, redirect_response)
270+
return RedirectResponse(url=link_user_url, headers=response.headers)
278271

279272
@router.get("/auth/unconnect/callback")
280273
async def unconnect_callback(request: Request, response: Response, auth_client: AuthClient = Depends(get_auth_client)):
@@ -284,13 +277,10 @@ async def unconnect_callback(request: Request, response: Response, auth_client:
284277
result = await auth_client.complete_unlink_user(callback_url, store_options={"request": request, "response": response})
285278
except Exception as e:
286279
raise HTTPException(status_code=400, detail=str(e))
287-
280+
288281
# Retrieve the returnTo parameter from appState if available
289282
return_to = result.get("app_state", {}).get("returnTo")
290283

291284
app_base_url = auth_client.config.app_base_url
292285

293-
redirect_response = RedirectResponse(url=return_to or app_base_url)
294-
295-
return merge_set_cookie_headers(response, redirect_response)
296-
286+
return RedirectResponse(url=return_to or app_base_url, headers=response.headers)

0 commit comments

Comments
 (0)