Skip to content

Commit 918fb51

Browse files
Invitation acceptance flow
1 parent 8a16159 commit 918fb51

File tree

9 files changed

+445
-41
lines changed

9 files changed

+445
-41
lines changed

exceptions/http_exceptions.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,4 +184,31 @@ def __init__(self):
184184
super().__init__(
185185
status_code=500, # Internal Server Error seems appropriate
186186
detail="Failed to send invitation email. Please try again later or contact support."
187+
)
188+
189+
190+
class InvalidInvitationTokenError(HTTPException):
191+
"""Raised when an invitation token is invalid, expired, or not found."""
192+
def __init__(self):
193+
super().__init__(
194+
status_code=404,
195+
detail="Invitation not found or expired"
196+
)
197+
198+
199+
class InvitationEmailMismatchError(HTTPException):
200+
"""Raised when a user attempts to accept an invitation sent to a different email address."""
201+
def __init__(self):
202+
super().__init__(
203+
status_code=403,
204+
detail="This invitation was sent to a different email address"
205+
)
206+
207+
208+
class InvitationProcessingError(HTTPException):
209+
"""Raised when an error occurs during the processing of a valid invitation."""
210+
def __init__(self, detail: str = "Failed to process invitation. Please try again later."):
211+
super().__init__(
212+
status_code=500, # Internal Server Error
213+
detail=detail
187214
)

main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from fastapi.templating import Jinja2Templates
88
from fastapi.exceptions import RequestValidationError
99
from starlette.exceptions import HTTPException as StarletteHTTPException
10-
from routers import account, dashboard, organization, role, user, static_pages, invitations
10+
from routers import account, dashboard, organization, role, user, static_pages, invitation
1111
from utils.dependencies import (
1212
get_optional_user
1313
)
@@ -46,7 +46,7 @@ async def lifespan(app: FastAPI):
4646

4747
app.include_router(account.router)
4848
app.include_router(dashboard.router)
49-
app.include_router(invitations.router)
49+
app.include_router(invitation.router)
5050
app.include_router(organization.router)
5151
app.include_router(role.router)
5252
app.include_router(static_pages.router)

routers/account.py

Lines changed: 156 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
from logging import getLogger
33
from typing import Optional, Tuple
44
from urllib.parse import urlparse
5-
from fastapi import APIRouter, Depends, BackgroundTasks, Form, Request
5+
from fastapi import APIRouter, Depends, BackgroundTasks, Form, Request, Query
66
from fastapi.responses import RedirectResponse
77
from fastapi.templating import Jinja2Templates
8+
from fastapi.exceptions import HTTPException
89
from starlette.datastructures import URLPath
910
from pydantic import EmailStr
1011
from sqlmodel import Session, select
11-
from utils.models import User, DataIntegrityError, Account
12+
from utils.models import User, DataIntegrityError, Account, Invitation
1213
from utils.db import get_session
1314
from utils.auth import (
1415
HTML_PASSWORD_PATTERN,
@@ -32,10 +33,15 @@
3233
from exceptions.http_exceptions import (
3334
EmailAlreadyRegisteredError,
3435
CredentialsError,
35-
PasswordValidationError
36+
PasswordValidationError,
37+
InvalidInvitationTokenError,
38+
InvitationEmailMismatchError,
39+
InvitationProcessingError
3640
)
3741
from routers.dashboard import router as dashboard_router
3842
from routers.user import router as user_router
43+
from routers.organization import router as org_router
44+
from utils.invitations import process_invitation
3945
logger = getLogger("uvicorn.error")
4046

4147
router = APIRouter(prefix="/account", tags=["account"])
@@ -97,7 +103,8 @@ def logout():
97103
async def read_login(
98104
request: Request,
99105
user: Optional[User] = Depends(get_optional_user),
100-
email_updated: Optional[str] = "false"
106+
email_updated: Optional[str] = Query("false"),
107+
invitation_token: Optional[str] = Query(None)
101108
):
102109
"""
103110
Render login page or redirect to dashboard if already logged in.
@@ -106,14 +113,21 @@ async def read_login(
106113
return RedirectResponse(url=dashboard_router.url_path_for("read_dashboard"), status_code=302)
107114
return templates.TemplateResponse(
108115
"account/login.html",
109-
{"request": request, "user": user, "email_updated": email_updated}
116+
{
117+
"request": request,
118+
"user": user,
119+
"email_updated": email_updated,
120+
"invitation_token": invitation_token
121+
}
110122
)
111123

112124

113125
@router.get("/register")
114126
async def read_register(
115127
request: Request,
116-
user: Optional[User] = Depends(get_optional_user)
128+
user: Optional[User] = Depends(get_optional_user),
129+
email: Optional[EmailStr] = Query(None),
130+
invitation_token: Optional[str] = Query(None)
117131
):
118132
"""
119133
Render registration page or redirect to dashboard if already logged in.
@@ -123,7 +137,13 @@ async def read_register(
123137

124138
return templates.TemplateResponse(
125139
"account/register.html",
126-
{"request": request, "user": user, "password_pattern": HTML_PASSWORD_PATTERN}
140+
{
141+
"request": request,
142+
"user": user,
143+
"password_pattern": HTML_PASSWORD_PATTERN,
144+
"email": email,
145+
"invitation_token": invitation_token
146+
}
127147
)
128148

129149

@@ -204,38 +224,98 @@ async def register(
204224
email: EmailStr = Form(...),
205225
session: Session = Depends(get_session),
206226
_: None = Depends(validate_password_strength_and_match),
207-
password: str = Form(...)
227+
password: str = Form(...),
228+
invitation_token: Optional[str] = Form(None)
208229
) -> RedirectResponse:
209230
"""
210-
Register a new user account.
231+
Register a new user account, optionally processing an invitation.
211232
"""
212233
# Check if the email is already registered
213-
account: Optional[Account] = session.exec(select(Account).where(
234+
existing_account: Optional[Account] = session.exec(select(Account).where(
214235
Account.email == email)).one_or_none()
215236

216-
if account:
237+
if existing_account:
217238
raise EmailAlreadyRegisteredError()
218239

219240
# Hash the password
220241
hashed_password = get_password_hash(password)
221242

222-
# Create the account
243+
# Create the account and user instances (don't commit yet)
223244
account = Account(email=email, hashed_password=hashed_password)
224245
session.add(account)
225-
session.flush() # Flush to get the account ID
226-
227-
# Create the user
228-
account.user = User(name=name, account_id=account.id)
229-
session.add(account)
230-
session.commit()
246+
session.flush() # Flush here to get account.id before creating User
247+
248+
# Ensure account has an ID after flush
249+
if not account.id:
250+
logger.error(f"Account ID not generated after flush for email {email}. Aborting registration.")
251+
session.rollback() # Rollback the account add
252+
raise DataIntegrityError(resource="Account ID generation")
253+
254+
new_user = User(name=name, account_id=account.id) # Use account.id
255+
session.add(new_user)
256+
257+
# Default redirect target
258+
redirect_url = dashboard_router.url_path_for("read_dashboard")
259+
260+
# Process invitation if token is provided (BEFORE final commit)
261+
if invitation_token:
262+
logger.info(f"Registration attempt with invitation token: {invitation_token} for email {email}")
263+
# Fetch the invitation
264+
statement = select(Invitation).where(Invitation.token == invitation_token)
265+
invitation = session.exec(statement).first()
266+
267+
if not invitation or not invitation.is_active():
268+
logger.warning(f"Invalid or inactive invitation token provided during registration: {invitation_token}")
269+
# Consider raising a more generic error to avoid exposing token validity
270+
raise InvalidInvitationTokenError()
271+
272+
# Verify email matches
273+
if email != invitation.invitee_email:
274+
logger.warning(
275+
f"Invitation email mismatch for token {invitation_token} during registration. "
276+
f"Account: {email}, Invitation: {invitation.invitee_email}"
277+
)
278+
# Consider raising a more generic error to avoid confirming email existence
279+
raise InvitationEmailMismatchError()
280+
281+
# Process the invitation (adds changes to the session)
282+
try:
283+
logger.info(f"Processing invitation {invitation.id} for new user {new_user.name} ({email}) during registration.")
284+
process_invitation(invitation, new_user, session)
285+
# Set redirect to the organization page
286+
redirect_url = org_router.url_path_for("read_organization", org_id=invitation.organization_id)
287+
logger.info(f"Redirecting new user {new_user.name} to organization {invitation.organization_id} after accepting invitation {invitation.id}.")
288+
except Exception as e:
289+
logger.error(
290+
f"Error processing invitation {invitation.id} for new user {new_user.name} ({email}) during registration: {e}",
291+
exc_info=True
292+
)
293+
session.rollback()
294+
raise InvitationProcessingError()
295+
296+
else:
297+
logger.info(f"Standard registration for email {email}. Redirecting to dashboard.")
298+
299+
# Commit all changes (Account, User, potentially Invitation)
300+
try:
301+
session.commit()
302+
except Exception as e:
303+
logger.error(f"Error committing transaction during registration for {email}: {e}", exc_info=True)
304+
session.rollback()
305+
# Use DataIntegrityError for commit failure
306+
raise DataIntegrityError(resource="Account/User registration")
307+
308+
# Refresh the account to ensure all relationships (like user) are loaded after commit
231309
session.refresh(account)
310+
# We might need the user object refreshed too if process_invitation modified it directly
311+
# session.refresh(new_user) # Let's assume process_invitation only modifies the invitation object for now
232312

233-
# Create access token
234-
access_token = create_access_token(data={"sub": email})
235-
refresh_token = create_refresh_token(data={"sub": email})
313+
# Create access token using the committed account's email
314+
access_token = create_access_token(data={"sub": account.email, "fresh": True})
315+
refresh_token = create_refresh_token(data={"sub": account.email})
236316

237317
# Set cookie
238-
response = RedirectResponse(url=dashboard_router.url_path_for("read_dashboard"), status_code=303)
318+
response = RedirectResponse(url=str(redirect_url), status_code=303) # Use determined redirect_url
239319
response.set_cookie(
240320
key="access_token",
241321
value=access_token,
@@ -256,21 +336,72 @@ async def register(
256336

257337
@router.post("/login", response_class=RedirectResponse)
258338
async def login(
259-
account_and_session: Tuple[Account, Session] = Depends(get_account_from_credentials)
339+
account_and_session: Tuple[Account, Session] = Depends(get_account_from_credentials),
340+
invitation_token: Optional[str] = Form(None)
260341
) -> RedirectResponse:
261342
"""
262-
Log in a user with valid credentials.
343+
Log in a user with valid credentials and process invitation if token is provided.
263344
"""
264345
account, session = account_and_session
265346

347+
# Default redirect target
348+
redirect_url = dashboard_router.url_path_for("read_dashboard")
349+
350+
if invitation_token:
351+
logger.info(f"Login attempt with invitation token: {invitation_token} for account {account.email}")
352+
# Fetch the invitation
353+
statement = select(Invitation).where(Invitation.token == invitation_token)
354+
invitation = session.exec(statement).first()
355+
356+
if not invitation or not invitation.is_active():
357+
logger.warning(f"Invalid or inactive invitation token provided during login: {invitation_token}")
358+
raise InvalidInvitationTokenError()
359+
360+
# Verify email matches
361+
if account.email != invitation.invitee_email:
362+
logger.warning(
363+
f"Invitation email mismatch for token {invitation_token}. "
364+
f"Account: {account.email}, Invitation: {invitation.invitee_email}"
365+
)
366+
raise InvitationEmailMismatchError()
367+
368+
# Ensure user relationship is loaded for process_invitation
369+
if not account.user:
370+
logger.debug(f"Refreshing user relationship for account {account.id}")
371+
session.refresh(account, attribute_names=["user"])
372+
if not account.user:
373+
# This should not happen if the account has a valid user relationship
374+
logger.error(f"Failed to load user for account {account.id} during invitation processing.")
375+
raise DataIntegrityError(resource="User relation")
376+
377+
# Process the invitation
378+
try:
379+
logger.info(f"Processing invitation {invitation.id} for user {account.user.id} during login.")
380+
process_invitation(invitation, account.user, session)
381+
session.commit()
382+
# Set redirect to the organization page
383+
redirect_url = org_router.url_path_for("read_organization", org_id=invitation.organization_id)
384+
logger.info(f"Redirecting user {account.user.id} to organization {invitation.organization_id} after accepting invitation {invitation.id}.")
385+
except Exception as e:
386+
logger.error(
387+
f"Error processing invitation {invitation.id} for user {account.user.id} during login: {e}",
388+
exc_info=True
389+
)
390+
session.rollback()
391+
# Raise the specific invitation processing error
392+
raise InvitationProcessingError()
393+
394+
else:
395+
logger.info(f"Standard login for account {account.email}. Redirecting to dashboard.")
396+
266397
# Create access token
267398
access_token = create_access_token(
268399
data={"sub": account.email, "fresh": True}
269400
)
270401
refresh_token = create_refresh_token(data={"sub": account.email})
271402

272403
# Set cookie
273-
response = RedirectResponse(url=dashboard_router.url_path_for("read_dashboard"), status_code=303)
404+
response = RedirectResponse(url=str(redirect_url), status_code=303)
274405
response.set_cookie(
275406
key="access_token",
276407
value=access_token,

0 commit comments

Comments
 (0)