Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion src/auth0_fastapi/auth/auth_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@

# Imported from auth0-server-python
from auth0_server_python.auth_server.server_client import ServerClient
from auth0_server_python.auth_types import LogoutOptions, StartInteractiveLoginOptions
from auth0_server_python.auth_types import (
CompleteConnectAccountResponse,
ConnectAccountOptions,
LogoutOptions,
StartInteractiveLoginOptions,
)
from fastapi import HTTPException, Request, Response, status

from auth0_fastapi.config import Auth0Config
Expand Down Expand Up @@ -45,6 +50,7 @@ def __init__(
transaction_store=transaction_store,
state_store=state_store,
pushed_authorization_requests=config.pushed_authorization_requests,
use_mrrt=config.use_mrrt,
authorization_params={
"audience": config.audience,
"redirect_uri": redirect_uri,
Expand Down Expand Up @@ -82,6 +88,36 @@ async def complete_login(
"""
return await self.client.complete_interactive_login(callback_url, store_options=store_options)

async def start_connect_account(
self,
connection: str,
app_state: dict = None,
authorization_params: dict = None,
store_options: dict = None,
) -> str:
"""
Initiates the connected account process.
Optionally, an app_state dictionary can be passed to persist additional state.
Returns the connect URL to redirect the user.
"""
options = ConnectAccountOptions(
connection=connection,
app_state=app_state,
authorization_params=authorization_params
)
return await self.client.start_connect_account(options=options, store_options=store_options)

async def complete_connect_account(
self,
url: str,
store_options: dict = None,
) -> CompleteConnectAccountResponse:
"""
Completes the connect account process using the callback URL.
Returns the completed connect account response.
"""
return await self.client.complete_connect_account(url, store_options=store_options)

async def logout(
self,
return_to: str = None,
Expand Down
2 changes: 2 additions & 0 deletions src/auth0_fastapi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ class Auth0Config(BaseModel):
audience: Optional[str] = Field(None, description="Target audience for tokens (if applicable)")
authorization_params: Optional[dict[str, Any]] = Field(None, description="Additional parameters to include in the authorization request")
pushed_authorization_requests: bool = Field(False, description="Whether to use pushed authorization requests")
use_mrrt: bool = Field(False, description="Whether to use Multi-Resource Refresh Tokens (MRRT)")
# Route-mounting flags with desired defaults
mount_routes: bool = Field(True, description="Controls /auth/* routes: login, logout, callback, backchannel-logout")
mount_connect_routes: bool = Field(False, description="Controls /auth/connect routes (account-linking)")
mount_connected_account_routes: bool = Field(False, description="Controls /auth/connect-account routes (for connected accounts)")
#Cookie Settings
cookie_name: str = Field("_a0_session", description="Name of the cookie storing session data")
session_expiration: int = Field(259200, description="Session expiration time in seconds (default: 3 days)")
Expand Down
47 changes: 40 additions & 7 deletions src/auth0_fastapi/server/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,25 +58,35 @@ async def callback(
):
"""
Endpoint to handle the callback after Auth0 authentication.
Processes the callback URL and completes the login flow.
Processes the callback URL and completes the login or connected account flow.
Redirects the user to a post-login URL based on appState or a default.
"""
full_callback_url = str(request.url)

try:
session_data = await auth_client.complete_login(
full_callback_url,
store_options={"request": request, "response": response},
)
if "connect_code" in request.query_params and config.mount_connected_account_routes:
connect_complete_response = await auth_client.complete_connect_account(
full_callback_url, store_options={"request": request, "response": response})

app_state = connect_complete_response.app_state or {}
else:
session_data = await auth_client.complete_login(
full_callback_url, store_options={"request": request, "response": response})

# Extract the returnTo URL from the appState if available.
app_state = session_data.get("app_state", {})
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))


# Extract the returnTo URL from the appState if available.
return_to = session_data.get("app_state", {}).get("returnTo")
return_to = app_state.get("returnTo")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not the fault of this PR, but this SDK should have something like this https://github.com/auth0/nextjs-auth0/blob/main/EXAMPLES.md#oncallback-hook - otherwise the api response gets lost

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeh I struggled with what to do after completion. I thought this was the best option for expediency to get something functional for the GA date as there's not actually a lot thats useful in the response. This SDK is pretty bare feature wise though compared to the JS SDK. I can add something like this in in a follow up PR


# Assuming config is stored on app.state
default_redirect = auth_client.config.app_base_url

return RedirectResponse(url=return_to or default_redirect, headers=response.headers)
safe_redirect = to_safe_redirect(return_to or default_redirect, auth_client.config.app_base_url)
return RedirectResponse(url=safe_redirect, headers=response.headers)

@router.get("/auth/logout")
async def logout(
Expand Down Expand Up @@ -123,7 +133,30 @@ async def backchannel_logout(
raise HTTPException(status_code=400, detail=str(e))
return Response(status_code=204)

if config.mount_connected_account_routes:
@router.get("/auth/connect-account")
async def connect_account(
request: Request,
response: Response,
connection: str = Query(),
auth_client: AuthClient = Depends(get_auth_client),
):
"""
Endpoint to initiate the connect account flow for linking a third-party account to the user's profile.
Redirects the user to the Auth0 connect account URL.
"""
authorization_params = {
k: v for k, v in request.query_params.items() if k not in ["connection", "returnTo"]}

return_to = request.query_params.get("returnTo")
connect_account_url = await auth_client.start_connect_account(
connection=connection,
app_state={"returnTo": return_to} if return_to else None,
authorization_params=authorization_params,
store_options={"request": request, "response": response},
)

return RedirectResponse(url=connect_account_url, headers=response.headers)

if config.mount_connect_routes:

Expand Down
46 changes: 46 additions & 0 deletions src/auth0_fastapi/test/test_auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import AsyncMock, Mock, patch

import pytest
from auth0_server_python.auth_types import CompleteConnectAccountResponse, ConnectAccountOptions
from fastapi import HTTPException, Request, Response

from auth0_fastapi.auth.auth_client import AuthClient
Expand Down Expand Up @@ -392,3 +393,48 @@ async def test_store_options_validation(self, auth_client):
await auth_client.start_login(store_options=valid_options)

mock_start.assert_called()


class TestConnectedAccountFlow:
"""Test connected account functionality."""

@pytest.mark.asyncio
async def test_start_connect_account(self, auth_client):
"""Test initiating user account linking."""
mock_connect_url = "https://test.auth0.com/connected-accounts/connect?ticket"

with patch.object(auth_client.client, 'start_connect_account', new_callable=AsyncMock) as mock_start_connect:
mock_start_connect.return_value = mock_connect_url

result = await auth_client.start_connect_account(
connection="google-oauth2",
app_state={"returnTo": "/profile"},
authorization_params={"prompt": "consent"},
)

assert result == mock_connect_url
mock_start_connect.assert_called_once_with(
options=ConnectAccountOptions(
connection="google-oauth2",
app_state={"returnTo": "/profile"},
authorization_params={"prompt": "consent"},
), store_options=None)

@pytest.mark.asyncio
async def test_complete_connect_account(self, auth_client):
"""Test initiating user account linking."""
mock_callback_url = "https://test.auth0.com/connected-accounts/connect?ticket"
mock_result = CompleteConnectAccountResponse(
id="id_12345",
connection="google-oauth2",
access_type="offline",
scopes=["read:foo"],
created_at="1970-01-01T00:00:00Z"
)
with patch.object(auth_client.client, 'complete_connect_account', new_callable=AsyncMock) as mock_complete:
mock_complete.return_value = mock_result

result = await auth_client.complete_connect_account(mock_callback_url)

assert result == mock_result
mock_complete.assert_called_once_with(mock_callback_url, store_options=None)
Loading